diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 00000000..c69283ec
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,4 @@
+aml
+target
+server/transformers
+server/flash-attention
diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
new file mode 100644
index 00000000..24ac3cbe
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -0,0 +1,67 @@
+name: "\U0001F41B Bug Report"
+description: Submit a bug report to help us improve text-generation-inference
+body:
+  - type: textarea
+    id: system-info
+    attributes:
+      label: System Info
+      description: |
+        Please share your system info with us (`text-generation-launcher --env` if installed locally).
+        The full command line that causes the issue:
+        OS version:
+        Rust version (if self-compiling, `cargo version`):
+        Model being used (`curl 127.0.0.1:8080/info | jq`):
+        If using a local model, please specify the kind of model and/or equivalents.
+        Hardware used (GPUs, how many, on which cloud) (`nvidia-smi`):
+        Deployment specificities (Kubernetes, EKS, AKS, any particular deployments):
+        The current version being used:
+
+      placeholder: text-generation-inference version, platform, python version, ...
+    validations:
+      required: true
+
+  - type: checkboxes
+    id: information-scripts-examples
+    attributes:
+      label: Information
+      description: 'The problem arises when using:'
+      options:
+        - label: "Docker"
+        - label: "The CLI directly"
+
+  - type: checkboxes
+    id: information-tasks
+    attributes:
+      label: Tasks
+      description: "The thing I am working on is:"
+      options:
+        - label: "An officially supported command"
+        - label: "My own modifications"
+
+  - type: textarea
+    id: reproduction
+    validations:
+      required: true
+    attributes:
+      label: Reproduction
+      description: |
+        Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
+        If you have code snippets, error messages, or stack traces, please provide them here as well.
+        Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
+        Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
+
+      placeholder: |
+        Steps to reproduce the behavior:
+
+          1.
+          2.
+          3.
+
+
+  - type: textarea
+    id: expected-behavior
+    validations:
+      required: true
+    attributes:
+      label: Expected behavior
+      description: "A clear and concise description of what you would expect to happen."
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 00000000..e6477729
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,2 @@
+blank_issues_enabled: true
+version: 2.1
diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml
new file mode 100644
index 00000000..f1a9135c
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature-request.yml
@@ -0,0 +1,31 @@
+name: "\U0001F680 Feature request"
+description: Submit a proposal/request for a new text-generation-inference feature
+labels: [ "feature" ]
+body:
+  - type: textarea
+    id: feature-request
+    validations:
+      required: true
+    attributes:
+      label: Feature request
+      description: |
+        A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.
+
+  - type: textarea
+    id: motivation
+    validations:
+      required: true
+    attributes:
+      label: Motivation
+      description: |
+        Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
+
+
+  - type: textarea
+    id: contribution
+    validations:
+      required: true
+    attributes:
+      label: Your contribution
+      description: |
+        Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.md [guide](https://github.com/huggingface/text-generation-inference/blob/main/CONTRIBUTING.md)
diff --git a/.github/ISSUE_TEMPLATE/new-model-addition.yml b/.github/ISSUE_TEMPLATE/new-model-addition.yml
new file mode 100644
index 00000000..2f3476d3
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/new-model-addition.yml
@@ -0,0 +1,31 @@
+name: "\U0001F31F New model addition"
+description: Submit a proposal/request to implement a new model
+labels: [ "New model" ]
+
+body:
+  - type: textarea
+    id: description-request
+    validations:
+      required: true
+    attributes:
+      label: Model description
+      description: |
+        Put any and all important information about the model
+
+  - type: checkboxes
+    id: information-tasks
+    attributes:
+      label: Open source status
+      description: |
+        Please note that if the model implementation isn't available or if the weights aren't open-source, we are less likely to implement it in `transformers`.
+      options:
+        - label: "The model implementation is available"
+        - label: "The model weights are available"
+
+  - type: textarea
+    id: additional-info
+    attributes:
+      label: Provide useful links for the implementation
+      description: |
+        Please provide information regarding the implementation, the weights, and the authors.
+        Please mention the authors by @gh-username if you're aware of their usernames.
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 00000000..ad5b98ab
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,40 @@
+# What does this PR do?
+
+
+
+
+
+Fixes # (issue)
+
+
+## Before submitting
+- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
+- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
+      Pull Request section?
+- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link
+      to it if that's the case.
+- [ ] Did you make sure to update the documentation with your changes? Here are the
+      [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and
+      [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
+- [ ] Did you write any new necessary tests?
+
+
+## Who can review?
+
+Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
+members/contributors who may be interested in your PR.
+
+
diff --git a/.github/workflows/autodocs.yml b/.github/workflows/autodocs.yml
new file mode 100644
index 00000000..48ed01e2
--- /dev/null
+++ b/.github/workflows/autodocs.yml
@@ -0,0 +1,20 @@
+name: Automatic Documentation for Launcher
+
+on:
+  pull_request:
+
+jobs:
+  update_docs:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v2
+
+      - name: Install Launcher
+        id: install-launcher
+        run: cargo install --path launcher/
+      - name: Check launcher Docs are up-to-date
+        run: |
+          echo text-generation-launcher --help
+          python update_doc.py --check
diff --git a/.github/workflows/build-kvrun.yml b/.github/workflows/build-kvrun.yml
new file mode 100644
index 00000000..92626f74
--- /dev/null
+++ b/.github/workflows/build-kvrun.yml
@@ -0,0 +1,100 @@
+name: Build and push kv.run docker image
+
+on:
+  workflow_dispatch:
+
+jobs:
+  build-and-push-image:
+    runs-on: [self-hosted, Linux, X64]
+
+    concurrency:
+      group: ${{ github.workflow }}-build-and-push-image-${{ github.head_ref || github.run_id }}
+      cancel-in-progress: true
+
+    permissions:
+      contents: write
+      packages: write
+      # This is used to complete the identity challenge
+      # with sigstore/fulcio when running outside of PRs.
+      id-token: write
+      security-events: write
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+      - name: Initialize Docker Buildx
+        uses: docker/setup-buildx-action@v2.0.0
+        with:
+          install: true
+      - name: Inject slug/short variables
+        uses: rlespinasse/github-slug-action@v4.4.1
+      - name: Login to GitHub Container Registry
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v2
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+      - name: Extract metadata (tags, labels) for Docker
+        id: meta
+        uses: docker/metadata-action@v4.3.0
+        with:
+          flavor: |
+            latest=auto
+          images: |
+            ghcr.io/${{env.GITHUB_REPOSITORY}}
+          tags: |
+            type=semver,pattern={{version}}
+            type=semver,pattern={{major}}.{{minor}}
+            type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
+            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
+      - name: Build and push Docker image
+        id: build-and-push
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          file: Dockerfile_kvrun
+          push: true
+          platforms: 'linux/amd64'
+          build-args: |
+            GIT_SHA=${{ env.GITHUB_SHA }}
+            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
+          #cache-from: type=gha
+          #cache-to: type=gha,mode=max
+
+# integration-tests:
+#   runs-on: [self-hosted, Linux, X64]
+#
+#   concurrency:
+#     group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
+#     cancel-in-progress: true
+#
+#   needs:
+#     - build-and-push-image # Wait for the docker image to be built
+#
+#   env:
+#     DOCKER_VOLUME: /cache
+#
+#   steps:
+#     - uses: actions/checkout@v2
+#     - name: Inject slug/short variables
+#       uses: rlespinasse/github-slug-action@v4.4.1
+#     - name: Set up Python
+#       uses: actions/setup-python@v4
+#       with:
+#         python-version: 3.10.14
+#     - name: Prepare disks
+#       run: |
+#         sudo mkfs -t ext4 /dev/nvme1n1
+#         sudo mkdir ${{ env.DOCKER_VOLUME }}
+#         sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
+#     - name: Install
+#       run: |
+#         make install-integration-tests
+#     - name: Run tests
+#       run: |
+#         export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
+#         export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+#         pytest -s -vv integration-tests
diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml
new file mode 100644
index 00000000..4d0b19a3
--- /dev/null
+++ b/.github/workflows/build_documentation.yml
@@ -0,0 +1,20 @@
+name: Build documentation
+
+on:
+  push:
+    paths:
+      - "docs/source/**"
+    branches:
+      - main
+      - doc-builder*
+      - v*-release
+
+jobs:
+  build:
+    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
+    with:
+      commit_sha: ${{ github.sha }}
+      package: text-generation-inference
+      additional_args: --not_python_module
+    secrets:
+      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml
new file mode 100644
index 00000000..a5ce39a5
--- /dev/null
+++ b/.github/workflows/build_pr_documentation.yml
@@ -0,0 +1,19 @@
+name: Build PR Documentation
+
+on:
+  pull_request:
+    paths:
+      - "docs/source/**"
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  build:
+    uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
+    with:
+      commit_sha: ${{ github.event.pull_request.head.sha }}
+      pr_number: ${{ github.event.number }}
+      package: text-generation-inference
+      additional_args: --not_python_module
diff --git a/.github/workflows/client-tests.yaml b/.github/workflows/client-tests.yaml
new file mode 100644
index 00000000..ef7c217c
--- /dev/null
+++ b/.github/workflows/client-tests.yaml
@@ -0,0 +1,26 @@
+name: Python Client Tests
+
+on:
+  pull_request:
+    paths:
+      - ".github/workflows/client-tests.yaml"
+      - "clients/python/**"
+
+jobs:
+  run_tests:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v1
+        with:
+          python-version: 3.9
+      - name: Install
+        run: |
+          cd clients/python && pip install .
+      - name: Run tests
+        run: |
+          pip install pytest pytest-asyncio
+          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          make python-client-tests
diff --git a/.github/workflows/load_test.yaml b/.github/workflows/load_test.yaml
new file mode 100644
index 00000000..fd22e395
--- /dev/null
+++ b/.github/workflows/load_test.yaml
@@ -0,0 +1,108 @@
+name: Nightly load test
+
+on:
+  schedule:
+    - cron: '0 0 * * 1-5'
+
+  pull_request:
+    paths:
+      - ".github/workflows/load_test.yaml"
+    branches:
+      - 'main'
+
+jobs:
+  start-runner:
+    name: Start self-hosted EC2 runner
+    runs-on: ubuntu-latest
+    env:
+      AWS_REGION: eu-central-1
+      EC2_AMI_ID: ami-0ab09c07cfd194259
+      EC2_INSTANCE_TYPE: g5.12xlarge
+      EC2_SUBNET_ID: subnet-988fd9f2,subnet-6f56db13,subnet-6a039326
+      EC2_SECURITY_GROUP: sg-072f92ae3082936c6
+    outputs:
+      label: ${{ steps.start-ec2-runner.outputs.label }}
+      ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
+    steps:
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v1
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ env.AWS_REGION }}
+      - name: Start EC2 runner
+        id: start-ec2-runner
+        uses: philschmid/philschmid-ec2-github-runner@main
+        with:
+          mode: start
+          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+          ec2-image-id: ${{ env.EC2_AMI_ID }}
+          ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
+          subnet-id: ${{ env.EC2_SUBNET_ID }}
+          security-group-id: ${{ env.EC2_SECURITY_GROUP }}
+          aws-resource-tags: > # optional, requires additional permissions
+            [
+              {"Key": "Name", "Value": "ec2-tgi-github-runner"},
+              {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
+            ]
+
+  load-tests:
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
+      cancel-in-progress: true
+    needs: start-runner # required to start the main job when the runner is ready
+    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
+    env:
+      DOCKER_VOLUME: /cache
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+
+      - name: Prepare disks
+        run: |
+          sudo mkfs -t ext4 /dev/nvme1n1
+          sudo mkdir ${{ env.DOCKER_VOLUME }}
+          sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
+
+      - name: Install k6
+        run: |
+          curl https://github.com/grafana/k6/releases/download/v0.44.0/k6-v0.44.0-linux-amd64.tar.gz -L | tar xvz --strip-components 1
+
+      - name: Start starcoder
+        run: |
+          docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v ${{ env.DOCKER_VOLUME }}:/data -e HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
+          sleep 10
+          wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health
+
+      - name: Run k6
+        run: |
+          ./k6 run load_tests/starcoder_load.js
+
+      - name: Stop starcoder
+        if: ${{ always() }}
+        run: |
+          docker stop tgi-starcoder || true
+
+  stop-runner:
+    name: Stop self-hosted EC2 runner
+    needs:
+      - start-runner
+      - load-tests
+    runs-on: ubuntu-latest
+    env:
+      AWS_REGION: eu-central-1
+    if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
+    steps:
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v1
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ env.AWS_REGION }}
+      - name: Stop EC2 runner
+        uses: philschmid/philschmid-ec2-github-runner@main
+        with:
+          mode: stop
+          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+          label: ${{ needs.start-runner.outputs.label }}
+          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
new file mode 100644
index 00000000..a5e50a79
--- /dev/null
+++ b/.github/workflows/stale.yml
@@ -0,0 +1,14 @@
+name: 'Close stale issues and PRs'
+on:
+  schedule:
+    - cron: '30 1 * * *'
+
+jobs:
+  stale:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/stale@v8
+        with:
+          stale-issue-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
+          days-before-stale: 30
+          days-before-close: 5
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
new file mode 100644
index 00000000..2ce3da89
--- /dev/null
+++ b/.github/workflows/tests.yml
@@ -0,0 +1,98 @@
+name: kv.run server tests
+
+on:
+  workflow_dispatch:
+  pull_request:
+    paths:
+      - ".github/workflows/tests.yml"
+      - "server/**"
+      - "proto/**"
+      - "router/**"
+      - "launcher/**"
+      - "Cargo.lock"
+      - "rust-toolchain.toml"
+  push:
+    paths:
+      - ".github/workflows/tests.yml"
+      - "server/**"
+      - "proto/**"
+      - "router/**"
+      - "launcher/**"
+      - "Cargo.lock"
+      - "rust-toolchain.toml"
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  run_tests:
+    runs-on: [self-hosted, Linux, X64]
+
+    env:
+      SCCACHE_GHA_ENABLED: "on"
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.10.14
+        env:
+          AGENT_TOOLSDIRECTORY: /opt/hostedtoolcache
+      - name: Install Rust
+        uses: actions-rs/toolchain@v1
+        with:
+          # Released on: 13 June, 2024
+          # https://releases.rs/docs/1.79.0/
+          toolchain: 1.79.0
+          override: true
+          components: rustfmt, clippy
+      - name: Install Protoc
+        uses: arduino/setup-protoc@v1
+      - name: Clean unused files
+        run: |
+          sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
+          sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
+      - name: Install sccache
+        uses: mozilla-actions/sccache-action@v0.0.4
+        with:
+          version: "v0.3.3"
+      - name: configure sccache
+        uses: actions/github-script@v6
+        with:
+          script: |
+            core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
+            core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
+            core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}');
+            core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-');
+            core.exportVariable('RUSTC_WRAPPER', process.env.SCCACHE_PATH || '');
+      - name: cargo registry cache
+        uses: actions/cache@v3
+        with:
+          key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }}
+          restore-keys: |
+            cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-
+            cargo-${{ runner.os }}-
+          path: |
+            ~/.cargo/registry
+            ~/.cargo/git
+      - name: Install
+        run: |
+          make install
+      - name: Run server tests
+        run: |
+          pip install pytest
+          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          CUDA_VISIBLE_DEVICES=-1 pytest -s -vv server/tests
+      - name: Pre-commit checks
+        run: |
+          pip install pre-commit
+          pre-commit install
+          pre-commit run --all-files
+      - name: Run Rust tests
+        run: |
+          cargo test
+      - name: sccache stats
+        run: |
+          ${SCCACHE_PATH} --show-stats
diff --git a/.github/workflows/upload_pr_documentation.yml b/.github/workflows/upload_pr_documentation.yml
new file mode 100644
index 00000000..ae00bb51
--- /dev/null
+++ b/.github/workflows/upload_pr_documentation.yml
@@ -0,0 +1,16 @@
+name: Upload PR Documentation
+
+on:
+  workflow_run:
+    workflows: ["Build PR Documentation"]
+    types:
+      - completed
+
+jobs:
+  build:
+    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
+    with:
+      package_name: text-generation-inference
+    secrets:
+      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
+      comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
diff --git a/.gitignore b/.gitignore
index fa917796..5a3c749e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,3 +20,19 @@ _build_meta.py
 lora_weights/
 .vscode/
 .venv
+.idea
+target
+router/tokenizer.json
+*__pycache__*
+
+# ROCm auto-generated files
+*.hip
+server/exllamav2_kernels/exllamav2_kernels/hip/
+server/exllama_kernels/exllama_kernels/hip/
+server/exllama_kernels/exllama_kernels/hip_func/
+*_hip.cuh
+server/exllama_kernels/exllama_kernels/hip_buffers.cuh
+server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
+
+data/
+load_tests/*.json
diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index 87db31d0..00000000
--- a/.gitmodules
+++ /dev/null
@@ -1,10 +0,0 @@
-[submodule "server/third_party/cutlass"]
-	path = server/third_party/cutlass
-	url = https://github.com/NVIDIA/cutlass
-[submodule "server/third_party/flashinfer"]
-	path = server/third_party/flashinfer
-	url = https://github.com/flashinfer-ai/flashinfer
-[submodule "third_party/text-generation-inference"]
-	path = third_party/text-generation-inference
-	url = https://github.com/huggingface/text-generation-inference
-	branch = 8f22cb9
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..45bc07a5
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,18 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.5.0
+  hooks:
+    - id: check-yaml
+    - id: end-of-file-fixer
+    - id: trailing-whitespace
+      exclude: docs/source/basic_tutorials/launcher.md
+- repo: https://github.com/psf/black
+  rev: 24.2.0
+  hooks:
+    - id: black
+- repo: https://github.com/doublify/pre-commit-rust
+  rev: v1.0
+  hooks:
+    - id: fmt
+    - id: cargo-check
+    - id: clippy
diff --git a/Cargo.lock b/Cargo.lock
new file mode 100644
index 00000000..9529fd42
--- /dev/null
+++ b/Cargo.lock
@@ -0,0 +1,4814 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3 + +[[package]] +name = "addr2line" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "serde", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "aligned-vec" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" + +[[package]] +name = "anstream" +version = "0.6.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" + +[[package]] +name = "anstyle-parse" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a64c907d4e79225ac72e2a354c9ce84d50ebb4586dee56c82b3ee73004f537f5" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" +dependencies = [ + "anstyle", + "windows-sys 0.52.0", +] + +[[package]] +name = "anyhow" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" + +[[package]] +name = "arbitrary" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" + +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + +[[package]] +name = "arg_enum_proc_macro" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "arrayvec" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" + +[[package]] 
+name = "async-rustls" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93b21a03b7c21702a0110f9f8d228763a533570deb376119042dabf33c37a01a" +dependencies = [ + "futures-io", + "rustls 0.20.9", + "webpki", +] + +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "async-trait" +version = "0.1.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "av1-grain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6678909d8c5d46a42abcf571271e15fdbc0a225e3646cf23762cd415046c78bf" +dependencies = [ + "anyhow", + "arrayvec", + "log", + "nom", + "num-rational", + "v_frame", +] + +[[package]] +name = "average" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c309b1c7fca12ebeec3ecba29ea917b3a4cb458ccf504df68bb4d8a0ca565a00" +dependencies = [ + "easy-cast", + "float-ord", + "num-traits", +] + +[[package]] +name = "avif-serialize" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "876c75a42f6364451a033496a14c44bffe41f5f4a8236f697391f11024e596d2" +dependencies = [ + "arrayvec", +] + +[[package]] +name = "awaitdrop" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "771051cdc7eec2dc1b23fbf870bb7fbb89136fe374227c875e377f1eed99a429" +dependencies = [ + "futures", + "generational-arena", + "parking_lot", + "slotmap", +] + +[[package]] +name = "axum" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" +dependencies = [ + "async-trait", + "axum-core 0.3.4", + "bitflags 1.3.2", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.28", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 0.1.2", + "tokio", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" +dependencies = [ + "async-trait", + "axum-core 0.4.3", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "hyper 1.3.1", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 
1.0.1", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 0.1.2", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-tracing-opentelemetry" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08" +dependencies = [ + "axum 0.7.5", + "futures-core", + "futures-util", + "http 1.1.0", + "opentelemetry 0.21.0", + "pin-project-lite", + "tower", + "tracing", + "tracing-opentelemetry 0.22.0", + "tracing-opentelemetry-instrumentation-sdk", +] + +[[package]] +name = "backtrace" +version = "0.3.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + +[[package]] +name = "bit_field" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc827186963e592360843fb5ba4b973e145841266c1357f7180c43526f2e5b61" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" + +[[package]] +name = "bitstream-io" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c12d1856e42f0d817a835fe55853957c85c8c8a470114029143d3f12671446e" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "built" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6a6c0b39c38fd754ac338b00a88066436389c0f029da5d37d1e01091d9b7c17" + +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + +[[package]] +name = "bytecount" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" + +[[package]] +name = "bytemuck" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "byteorder-lite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" + +[[package]] +name = "bytes" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" + +[[package]] +name = "camino" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0ec6b951b160caa93cc0c7b209e5a3bff7aae9062213451ac99493cd844c239" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo-platform" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24b1f0365a6c6bb4020cd05806fd0d33c44d38046b8bd7f0e40814b9763cabfc" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo_metadata" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d886547e41f740c616ae73108f6eb70afe6d940c7bc697cb30f13daec073037" +dependencies = [ + "camino", + "cargo-platform", + "semver", + "serde", + "serde_json", + "thiserror", +] + +[[package]] +name = "cassowary" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" + +[[package]] +name = "cc" +version = "1.0.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" +dependencies = [ + "jobserver", + "libc", + "once_cell", +] + +[[package]] +name = "cfg-expr" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d067ad48b8650848b989a59a86c6c36a995d02d2bf778d45c3c5d57bc2718f02" +dependencies = [ + "smallvec", + "target-lexicon", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "cfg_aliases" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" + +[[package]] +name = "clap" +version = "4.5.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "clap_lex" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" + +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + +[[package]] +name = "colorchoice" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" + +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys 0.52.0", +] + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "cpufeatures" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "crossterm" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +dependencies = [ + "bitflags 2.5.0", + "crossterm_winapi", + "libc", + "mio", + "parking_lot", + "signal-hook", + "signal-hook-mio", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "ctrlc" +version = "3.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" +dependencies = [ + "nix", + "windows-sys 0.52.0", +] + +[[package]] +name = "darling" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.66", +] + +[[package]] +name = "darling_macro" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + +[[package]] +name = "derive_builder" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" +dependencies = [ + "derive_builder_core", + "syn 2.0.66", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "dirs" +version = "5.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + +[[package]] +name = "easy-cast" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10936778145f3bea71fd9bf61332cce28c28e96a380714f7ab34838b80733fd6" +dependencies = [ + "libm", +] + +[[package]] +name = "either" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" + +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + +[[package]] +name = "encoding_rs" +version = "0.8.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" +dependencies = [ + "cc", +] + +[[package]] +name = "exr" +version = "1.72.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4" +dependencies = [ + "bit_field", + "flume", + "half", + "lebe", + "miniz_oxide", + "rayon-core", + "smallvec", + "zune-inflate", +] + +[[package]] +name = "fancy-regex" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" +dependencies = [ + "bit-set", + "regex", +] + +[[package]] +name = "fastrand" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" + +[[package]] +name = "fdeflate" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645" +dependencies = [ + "simd-adler32", +] + +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + +[[package]] +name = "flate2" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "float-ord" +version = "0.3.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d" + +[[package]] +name = "float_eq" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853" + +[[package]] +name = "flume" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +dependencies = [ + "spin 0.9.8", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "fraction" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3027ae1df8d41b4bed2241c8fdad4acc1e7af60c8e17743534b545e77182d678" +dependencies = [ + "lazy_static", + "num", +] + +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "generational-arena" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877e94aff08e743b651baaea359664321055749b398adff8740a7399af7796e7" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + +[[package]] +name = "gif" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb2d69b19215e18bb912fa30f7ce15846e301408695e44e0ef719f1da9e19f2" +dependencies = [ + "color_quant", + "weezl", +] + +[[package]] +name = "gimli" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "grpc-metadata" +version = "0.1.0" +dependencies = [ + "opentelemetry 0.20.0", + "tonic 0.10.2", + "tracing", + "tracing-opentelemetry 0.21.0", +] + +[[package]] +name = "h2" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + "indexmap 2.2.6", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hashbrown" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ff8ae62cd3a9102e5637afc8452c55acf3844001bd5374e0b0bd7b6616c038" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + 
+[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "hf-hub" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" +dependencies = [ + "dirs", + "futures", + "indicatif", + "log", + "native-tls", + "num_cpus", + "rand", + "reqwest", + "serde", + "serde_json", + "thiserror", + "tokio", + "ureq", +] + +[[package]] +name = "hostname" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" +dependencies = [ + "libc", + "match_cfg", + "winapi", +] + +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +dependencies = [ + "bytes", + "http 1.1.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0475f8b2ac86659c21b64320d5d653f9efe42acd2a4e560073ec61a155a34f1d" +dependencies = [ + "bytes", + "futures-core", + "http 1.1.0", + "http-body 1.0.0", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "0.14.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", + "want", +] + +[[package]] +name = "hyper" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "httparse", + "httpdate", + "itoa", + 
"pin-project-lite", + "smallvec", + "tokio", +] + +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper 0.14.28", + "pin-project-lite", + "tokio", + "tokio-io-timeout", +] + +[[package]] +name = "hyper-tls" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" +dependencies = [ + "bytes", + "hyper 0.14.28", + "native-tls", + "tokio", + "tokio-native-tls", +] + +[[package]] +name = "hyper-util" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d8d52be92d09acc2e01dddb7fde3ad983fc6489c7db4837e605bc3fca4cb63e" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "hyper 1.3.1", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "image" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11" +dependencies = [ + "bytemuck", + "byteorder", + "color_quant", + "exr", + "gif", + "image-webp", + "num-traits", + "png", + "qoi", + "ravif", + "rayon", + "rgb", + "tiff", + "zune-core", + "zune-jpeg", +] + +[[package]] +name = "image-webp" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d730b085583c4d789dfd07fdcf185be59501666a90c97c40162b37e4fdad272d" +dependencies = [ + "byteorder-lite", + "thiserror", +] + +[[package]] +name = "imgref" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44feda355f4159a7c757171a77de25daf6411e217b4cabd03bd6650690468126" + +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] + +[[package]] +name = "indexmap" +version = "2.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +dependencies = [ + "equivalent", + "hashbrown 0.14.5", + "serde", +] + +[[package]] +name = "indicatif" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "unicode-width", +] + +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + +[[package]] +name = "init-tracing-opentelemetry" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94bd26b1b737bc11f183620072e188d1c6ede67e0e78682228d66b49ec510e17" +dependencies = [ 
+ "opentelemetry 0.20.0", + "opentelemetry-otlp", + "thiserror", + "tracing", + "tracing-opentelemetry 0.21.0", +] + +[[package]] +name = "instant" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "interpolate_name" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "ipnet" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" + +[[package]] +name = "iso8601" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "924e5d73ea28f59011fec52a0d12185d496a9b075d360657aed2a5707f701153" +dependencies = [ + "nom", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "jobserver" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +dependencies = [ + "libc", +] + +[[package]] +name = "jpeg-decoder" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" + +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "jsonschema" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a071f4f7efc9a9118dfb627a0a94ef247986e1ab8606a4c806ae2b3aa3b6978" +dependencies = [ + "ahash", + "anyhow", + "base64 0.21.7", + "bytecount", + "clap", + "fancy-regex", + "fraction", + "getrandom", + "iso8601", + "itoa", + "memchr", + "num-cmp", + "once_cell", + "parking_lot", + "percent-encoding", + "regex", + "reqwest", + "serde", + "serde_json", + "time", + "url", + "uuid", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = 
"lebe" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" + +[[package]] +name = "libc" +version = "0.2.155" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" + +[[package]] +name = "libfuzzer-sys" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a96cfd5557eb82f2b83fed4955246c988d331975a002961b07c81584d107e7f7" +dependencies = [ + "arbitrary", + "cc", + "once_cell", +] + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags 2.5.0", + "libc", +] + +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" + +[[package]] +name = "loop9" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fae87c125b03c1d2c0150c90365d7d6bcc53fb73a9acaef207d2d065860f062" +dependencies = [ + "imgref", +] + +[[package]] +name = "mach2" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b955cdeb2a02b9117f121ce63aa52d08ade45de53e48fe6a38b39c10f6f709" +dependencies = [ + "libc", +] + +[[package]] +name = "macro_rules_attribute" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a82271f7bc033d84bbca59a3ce3e4159938cb08a9c3aebbe54d215131518a13" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568" + +[[package]] +name = "match_cfg" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" + +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + +[[package]] +name = "maybe-rayon" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519" +dependencies = [ + "cfg-if", + 
"rayon", +] + +[[package]] +name = "memchr" +version = "2.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" + +[[package]] +name = "metrics" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5" +dependencies = [ + "ahash", + "metrics-macros", + "portable-atomic", +] + +[[package]] +name = "metrics-exporter-prometheus" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d4fa7ce7c4862db464a37b0b31d89bca874562f034bd7993895572783d02950" +dependencies = [ + "base64 0.21.7", + "hyper 0.14.28", + "indexmap 1.9.3", + "ipnet", + "metrics", + "metrics-util", + "quanta", + "thiserror", + "tokio", + "tracing", +] + +[[package]] +name = "metrics-macros" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "metrics-util" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4de2ed6e491ed114b40b732e4d1659a9d53992ebd87490c44a6ffe23739d973e" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", + "hashbrown 0.13.1", + "metrics", + "num_cpus", + "quanta", + "sketches-ddsketch", +] + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mime_guess" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +dependencies = [ + "mime", + "unicase", +] + +[[package]] +name = "minijinja" +version = "1.0.12" +source = "git+https://github.com/mitsuhiko/minijinja.git?rev=5cd4efb#5cd4efb9e2639247df275fe6e22a5dbe0ce71b28" +dependencies = [ + "serde", +] + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae" +dependencies = [ + "adler", + "simd-adler32", +] + +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys 0.48.0", +] + +[[package]] +name = "monostate" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d208407d7552cd041d8cdb69a1bc3303e029c598738177a3d87082004dc0e1e" +dependencies = [ + "monostate-impl", + "serde", +] + +[[package]] +name = "monostate-impl" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "multimap" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" + +[[package]] +name = "muxado" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e92b89ac3127251efde6f5a9586e5aae99468d06fcf9f133b377f58d5ed66446" +dependencies = [ + "async-trait", + "awaitdrop", + "bitflags 1.3.2", + "bytes", + "futures", + "pin-project", + "rand", + "thiserror", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "native-tls" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" +dependencies = [ + "lazy_static", + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "new_debug_unreachable" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" + +[[package]] +name = "ngrok" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1454b1edbc5f2c8ff3242c237cb84388b50eced8eb26b4204e49698ed6511784" +dependencies = [ + "arc-swap", + "async-rustls", + "async-trait", + "awaitdrop", + "axum 0.6.20", + "base64 0.13.1", + "bytes", + "futures", + "hostname", + "hyper 0.14.28", + "muxado", + "once_cell", + "parking_lot", + "regex", + "rustls-pemfile", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-retry", + "tokio-util", + "tracing", + "windows-sys 0.45.0", +] + +[[package]] +name = "nix" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" +dependencies = [ + "bitflags 2.5.0", + "cfg-if", + "cfg_aliases", + "libc", +] + +[[package]] +name = "nohash-hasher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "noop_proc_macro" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" + +[[package]] +name = "ntapi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +dependencies = [ + "winapi", +] + +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +dependencies = [ + 
"num-integer", + "num-traits", +] + +[[package]] +name = "num-cmp" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa" + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "num_threads" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7398b9c8b70908f6371f47ed36737907c87c52af34c268fed0bf0ceb92ead9" +dependencies = [ + "libc", +] + +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + +[[package]] +name = "object" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "onig" +version = "6.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c4b31c8722ad9171c6d77d3557db078cab2bd50afcc9d09c8b315c59df8ca4f" +dependencies = [ + "bitflags 1.3.2", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b829e3d7e9cc74c7e315ee8edb185bf4190da5acde74afd7fc59c35b1f086e7" +dependencies = 
[ + "cc", + "pkg-config", +] + +[[package]] +name = "openssl" +version = "0.10.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +dependencies = [ + "bitflags 2.5.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-sys" +version = "0.9.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "opentelemetry" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9591d937bc0e6d2feb6f71a559540ab300ea49955229c347a517a28d27784c54" +dependencies = [ + "opentelemetry_api", + "opentelemetry_sdk 0.20.0", +] + +[[package]] +name = "opentelemetry" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a" +dependencies = [ + "futures-core", + "futures-sink", + "indexmap 2.2.6", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", + "urlencoding", +] + +[[package]] +name = "opentelemetry-otlp" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e5e5a5c4135864099f3faafbe939eb4d7f9b80ebf68a8448da961b32a7c1275" +dependencies = [ + "async-trait", + "futures-core", + "http 0.2.12", + "opentelemetry-proto", + "opentelemetry-semantic-conventions", + "opentelemetry_api", + "opentelemetry_sdk 0.20.0", + "prost 0.11.9", + "thiserror", + "tokio", + "tonic 0.9.2", +] + +[[package]] +name = "opentelemetry-proto" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1e3f814aa9f8c905d0ee4bde026afd3b2577a97c10e1699912e3e44f0c4cbeb" +dependencies = [ + "opentelemetry_api", + "opentelemetry_sdk 0.20.0", + "prost 0.11.9", + "tonic 0.9.2", +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73c9f9340ad135068800e7f1b24e9e09ed9e7143f5bf8518ded3d3ec69789269" +dependencies = [ + "opentelemetry 0.20.0", +] + +[[package]] +name = "opentelemetry_api" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a81f725323db1b1206ca3da8bb19874bbd3f57c3bcd59471bfb04525b265b9b" +dependencies = [ + "futures-channel", + "futures-util", + "indexmap 1.9.3", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", + "urlencoding", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa8e705a0612d48139799fcbaba0d4a90f06277153e43dd2bdc16c6f0edd8026" +dependencies = [ + "async-trait", + "crossbeam-channel", + "futures-channel", + "futures-executor", + "futures-util", + "once_cell", + "opentelemetry_api", + "ordered-float 3.9.2", + 
"percent-encoding", + "rand", + "regex", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f16aec8a98a457a52664d69e0091bac3a0abd18ead9b641cb00202ba4e0efe4" +dependencies = [ + "async-trait", + "crossbeam-channel", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "once_cell", + "opentelemetry 0.21.0", + "ordered-float 4.2.0", + "percent-encoding", + "rand", + "thiserror", +] + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + +[[package]] +name = "ordered-float" +version = "3.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1e1c390732d15f1d48471625cd92d154e66db2c56645e29a9cd26f4699f72dc" +dependencies = [ + "num-traits", +] + +[[package]] +name = "ordered-float" +version = "4.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +dependencies = [ + "num-traits", +] + +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + +[[package]] +name = "papergrid" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2ccbe15f2b6db62f9a9871642746427e297b0ceb85f9a7f1ee5ff47d184d0c8" +dependencies = [ + "bytecount", + "fnv", + "unicode-width", +] + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.5", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap 2.2.6", +] + +[[package]] +name = "pin-project" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.14" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + +[[package]] +name = "png" +version = "0.17.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1" +dependencies = [ + "bitflags 1.3.2", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "prettyplease" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" +dependencies = [ + "proc-macro2", + "syn 2.0.66", +] + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + +[[package]] +name = "proc-macro2" +version = "1.0.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "profiling" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d84d1d7a6ac92673717f9f6d1518374ef257669c24ebc5ac25d5033828be58" +dependencies = [ + "profiling-procmacros", +] + +[[package]] +name = "profiling-procmacros" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" +dependencies = [ + "quote", + "syn 2.0.66", +] + +[[package]] +name = "prost" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" +dependencies = [ + "bytes", + "prost-derive 0.11.9", +] + +[[package]] +name = "prost" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" +dependencies = [ + "bytes", + "prost-derive 0.12.6", +] + 
+[[package]] +name = "prost-build" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" +dependencies = [ + "bytes", + "heck 0.5.0", + "itertools 0.12.1", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease", + "prost 0.12.6", + "prost-types", + "regex", + "syn 2.0.66", + "tempfile", +] + +[[package]] +name = "prost-derive" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" +dependencies = [ + "anyhow", + "itertools 0.10.5", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "prost-derive" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" +dependencies = [ + "anyhow", + "itertools 0.12.1", + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "prost-types" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0" +dependencies = [ + "prost 0.12.6", +] + +[[package]] +name = "qoi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f6d64c71eb498fe9eae14ce4ec935c555749aef511cca85b5568910d6e48001" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "quanta" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab" +dependencies = [ + "crossbeam-utils", + "libc", + "mach2", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + +[[package]] +name = "quick-error" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "ratatui" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e2e4cd95294a85c3b4446e63ef054eea43e0205b1fd60120c16b74ff7ff96ad" +dependencies = [ + "bitflags 2.5.0", + "cassowary", + "crossterm", + "indoc", + "itertools 0.11.0", + "paste", + "strum", + "unicode-segmentation", + "unicode-width", +] + +[[package]] +name = "rav1e" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"cd87ce80a7665b1cce111f8a16c1f3929f6547ce91ade6addf4ec86a8dda5ce9" +dependencies = [ + "arbitrary", + "arg_enum_proc_macro", + "arrayvec", + "av1-grain", + "bitstream-io", + "built", + "cfg-if", + "interpolate_name", + "itertools 0.12.1", + "libc", + "libfuzzer-sys", + "log", + "maybe-rayon", + "new_debug_unreachable", + "noop_proc_macro", + "num-derive", + "num-traits", + "once_cell", + "paste", + "profiling", + "rand", + "rand_chacha", + "simd_helpers", + "system-deps", + "thiserror", + "v_frame", + "wasm-bindgen", +] + +[[package]] +name = "ravif" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc13288f5ab39e6d7c9d501759712e6969fcc9734220846fc9ed26cae2cc4234" +dependencies = [ + "avif-serialize", + "imgref", + "loop9", + "quick-error", + "rav1e", + "rayon", + "rgb", +] + +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-cond" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" +dependencies = [ + "either", + "itertools 0.11.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +dependencies = [ + "bitflags 2.5.0", +] + +[[package]] +name = "redox_users" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + +[[package]] +name = "regex" +version = "1.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata 0.4.6", + "regex-syntax 0.8.3", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", +] + +[[package]] +name = "regex-automata" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.8.3", +] + +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + +[[package]] +name = "regex-syntax" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" + +[[package]] +name = "reqwest" +version = "0.11.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" +dependencies = [ + "base64 0.21.7", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.28", + "hyper-tls", + "ipnet", + "js-sys", + "log", + "mime", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls-pemfile", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 0.1.2", + "system-configuration", + "tokio", + "tokio-native-tls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "winreg", +] + +[[package]] +name = "rgb" +version = "0.8.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05aaa8004b64fd573fc9d002f4e632d51ad4f026c2b5ba95fcb6c2f32c2c47d8" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin 0.5.2", + "untrusted 0.7.1", + "web-sys", + "winapi", +] + +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin 0.9.8", + "untrusted 0.9.0", + "windows-sys 0.52.0", +] + +[[package]] +name = "rust-embed" +version = "8.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19549741604902eb99a7ed0ee177a0663ee1eda51a29f71401f166e47e77806a" +dependencies = [ + "rust-embed-impl", + "rust-embed-utils", + "walkdir", +] + +[[package]] +name = "rust-embed-impl" +version = "8.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb9f96e283ec64401f30d3df8ee2aaeb2561f34c824381efa24a35f79bf40ee4" +dependencies = [ + "proc-macro2", + "quote", + "rust-embed-utils", + "syn 2.0.66", + "walkdir", +] + +[[package]] +name = "rust-embed-utils" +version = "8.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c74a686185620830701348de757fd36bef4aa9680fd23c49fc539ddcc1af32" +dependencies = [ + "sha2", + "walkdir", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + +[[package]] +name = "rustix" +version = "0.38.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +dependencies = [ + "bitflags 2.5.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustls" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" +dependencies = [ + "log", + "ring 0.16.20", + "sct", + "webpki", +] + +[[package]] +name = "rustls" 
+version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +dependencies = [ + "log", + "ring 0.17.8", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64 0.21.7", +] + +[[package]] +name = "rustls-pki-types" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" + +[[package]] +name = "rustls-webpki" +version = "0.102.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" +dependencies = [ + "ring 0.17.8", + "rustls-pki-types", + "untrusted 0.9.0", +] + +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "schannel" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring 0.17.8", + "untrusted 0.9.0", +] + +[[package]] +name = "security-framework" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +dependencies = [ + "bitflags 2.5.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +dependencies = [ + "serde", +] + +[[package]] +name = "serde" +version = "1.0.203" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.203" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "serde_json" +version = "1.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6" +dependencies = [ + "itoa", + "serde", +] + +[[package]] +name = "serde_spanned" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0" +dependencies = [ + "serde", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-mio" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af" +dependencies = [ + "libc", + "mio", + "signal-hook", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + +[[package]] +name = "simd_helpers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95890f873bec569a0362c235787f3aca6e1e887302ba4840839bcc6459c42da6" +dependencies = [ + "quote", +] + +[[package]] +name = "sketches-ddsketch" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85636c14b73d81f541e525f585c0a2109e6744e1565b5c1668e31c70c10ed65c" + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "slotmap" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" +dependencies = [ + "version_check", +] + +[[package]] +name = "smallvec" 
+version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "socket2" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.66", +] + +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + +[[package]] +name = "sync_wrapper" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" + +[[package]] +name = "sysinfo" +version = "0.30.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "732ffa00f53e6b2af46208fba5718d9662a421049204e156328b66791ffa15ae" +dependencies = [ + "cfg-if", + "core-foundation-sys", + "libc", + "ntapi", + "once_cell", + "windows", +] + +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + 
"core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "system-deps" +version = "6.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349" +dependencies = [ + "cfg-expr", + "heck 0.5.0", + "pkg-config", + "toml", + "version-compare", +] + +[[package]] +name = "tabled" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfe9c3632da101aba5131ed63f9eed38665f8b3c68703a6bb18124835c1a5d22" +dependencies = [ + "papergrid", + "tabled_derive", + "unicode-width", +] + +[[package]] +name = "tabled_derive" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99f688a08b54f4f02f0a3c382aefdb7884d3d69609f785bd253dc033243e3fe4" +dependencies = [ + "heck 0.4.1", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "target-lexicon" +version = "0.12.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" + +[[package]] +name = "tempfile" +version = "3.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +dependencies = [ + "cfg-if", + "fastrand", + "rustix", + "windows-sys 0.52.0", +] + +[[package]] +name = "text-generation-benchmark" +version = "2.0.5-dev0" +dependencies = [ + "average", + "clap", + "crossterm", + "float-ord", + "hf-hub", + "ratatui", + "serde", + "serde_json", + "tabled", + "text-generation-client", + "thiserror", + "tokenizers", + "tokio", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "text-generation-client" +version = "2.0.5-dev0" +dependencies = [ + "futures", + "grpc-metadata", + "prost 0.12.6", + "prost-build", + "thiserror", + "tokio", + "tonic 0.10.2", + "tonic-build", + "tower", + "tracing", +] + +[[package]] +name = "text-generation-launcher" +version = "2.0.5-dev0" +dependencies = [ + "bitstream-io", + "clap", + "ctrlc", + "float_eq", + "hf-hub", + "nix", + "once_cell", + "reqwest", + "serde", + "serde_json", + "thiserror", + "tracing", + "tracing-subscriber", + "vergen", +] + +[[package]] +name = "text-generation-router" +version = "2.0.5-dev0" +dependencies = [ + "async-stream", + "axum 0.7.5", + "axum-tracing-opentelemetry", + "base64 0.22.1", + "bitstream-io", + "clap", + "futures", + "futures-util", + "hf-hub", + "image", + "init-tracing-opentelemetry", + "jsonschema", + "metrics", + "metrics-exporter-prometheus", + "minijinja", + "ngrok", + "nohash-hasher", + "once_cell", + "opentelemetry 0.20.0", + "opentelemetry-otlp", + "rand", + "regex", + "reqwest", + "serde", + "serde_json", + "text-generation-client", + "thiserror", + "tokenizers", + "tokio", + "tokio-stream", + "tower-http", + "tracing", + "tracing-opentelemetry 0.21.0", + "tracing-subscriber", + "utoipa", + "utoipa-swagger-ui", + "vergen", +] + +[[package]] +name = "thiserror" +version = "1.0.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = 
"thiserror-impl" +version = "1.0.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + +[[package]] +name = "tiff" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba1310fcea54c6a9a4fd1aad794ecc02c31682f6bfbecdf460bf19533eed1e3e" +dependencies = [ + "flate2", + "jpeg-decoder", + "weezl", +] + +[[package]] +name = "time" +version = "0.3.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +dependencies = [ + "deranged", + "itoa", + "libc", + "num-conv", + "num_threads", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +dependencies = [ + "num-conv", + "time-core", +] + +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokenizers" +version = "0.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd" +dependencies = [ + "aho-corasick", + "derive_builder", + "esaxx-rs", + "getrandom", + "hf-hub", + "indicatif", + "itertools 0.12.1", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.8.3", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + +[[package]] +name = "tokio" +version = "1.37.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "num_cpus", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.48.0", +] + +[[package]] +name = "tokio-io-timeout" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf" +dependencies = [ + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-macros" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", 
+] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-retry" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f" +dependencies = [ + "pin-project", + "rand", + "tokio", +] + +[[package]] +name = "tokio-stream" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +dependencies = [ + "bytes", + "futures-core", + "futures-io", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "toml" +version = "0.8.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4e43f8cc456c9704c851ae29c67e17ef65d2c30017c17a9765b89c382dc8bba" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c127785850e8c20836d49732ae6abfa47616e60bf9d9f57c43c250361a9db96c" +dependencies = [ + "indexmap 2.2.6", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + +[[package]] +name = "tonic" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3082666a3a6433f7f511c7192923fa1fe07c69332d3c6a2e6bb040b569199d5a" +dependencies = [ + "async-trait", + "axum 0.6.20", + "base64 0.21.7", + "bytes", + "futures-core", + "futures-util", + "h2", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.28", + "hyper-timeout", + "percent-encoding", + "pin-project", + "prost 0.11.9", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d560933a0de61cf715926b9cac824d4c883c2c43142f787595e48280c40a1d0e" +dependencies = [ + "async-stream", + "async-trait", + "axum 0.6.20", + "base64 0.21.7", + "bytes", + "h2", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.28", + "hyper-timeout", + "percent-encoding", + "pin-project", + "prost 0.12.6", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-build" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d021fc044c18582b9a2408cd0dd05b1596e3ecdb5c4df822bb0183545683889" +dependencies = [ + "prettyplease", + "proc-macro2", + "prost-build", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", 
+ "indexmap 1.9.3", + "pin-project", + "pin-project-lite", + "rand", + "slab", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags 2.5.0", + "bytes", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "pin-project-lite", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + +[[package]] +name = "tower-service" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "log", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f751112709b4e791d8ce53e32c4ed2d353565a795ce84da2285393f41557bdf2" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75327c6b667828ddc28f5e3f169036cb793c3f588d83bf0f262a7f062ffed3c8" +dependencies = [ + "once_cell", + "opentelemetry 0.20.0", + "opentelemetry_sdk 0.20.0", + "smallvec", + "tracing", + "tracing-core", + "tracing-log 0.1.4", + "tracing-subscriber", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c67ac25c5407e7b961fafc6f7e9aa5958fd297aada2d20fa2ae1737357e55596" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry 0.21.0", + "opentelemetry_sdk 0.21.2", + "smallvec", + "tracing", + "tracing-core", + "tracing-log 0.2.0", + "tracing-subscriber", + "web-time", +] + +[[package]] +name = "tracing-opentelemetry-instrumentation-sdk" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9920abb6a3ee3a2af7d30c9ff02900f8481935d36723c3da95cf807468218e8c" +dependencies = [ + "http 1.1.0", + "opentelemetry 0.21.0", + "tracing", + "tracing-opentelemetry 0.22.0", +] + +[[package]] +name = "tracing-serde" +version = "0.1.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1" +dependencies = [ + "serde", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "serde", + "serde_json", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log 0.2.0", + "tracing-serde", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unicase" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" +dependencies = [ + "version_check", +] + +[[package]] +name = "unicode-bidi" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unicode-normalization" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + +[[package]] +name = "unicode-segmentation" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" + +[[package]] +name = "unicode-width" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" + +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "native-tls", + "once_cell", + "rustls 0.22.4", + "rustls-pki-types", + "rustls-webpki", + "serde", + 
"serde_json", + "url", + "webpki-roots", +] + +[[package]] +name = "url" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + +[[package]] +name = "utoipa" +version = "4.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" +dependencies = [ + "indexmap 2.2.6", + "serde", + "serde_json", + "utoipa-gen", +] + +[[package]] +name = "utoipa-gen" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bf0e16c02bc4bf5322ab65f10ab1149bdbcaa782cba66dc7057370a3f8190be" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "regex", + "syn 2.0.66", +] + +[[package]] +name = "utoipa-swagger-ui" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da" +dependencies = [ + "axum 0.7.5", + "mime_guess", + "regex", + "rust-embed", + "serde", + "serde_json", + "utoipa", + "zip", +] + +[[package]] +name = "uuid" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" + +[[package]] +name = "v_frame" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6f32aaa24bacd11e488aa9ba66369c7cd514885742c9fe08cfe85884db3e92b" +dependencies = [ + "aligned-vec", + "num-traits", + "wasm-bindgen", +] + +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "vergen" +version = "8.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e27d6bdd219887a9eadd19e1c34f32e47fa332301184935c6d9bca26f3cca525" +dependencies = [ + "anyhow", + "cargo_metadata", + "cfg-if", + "regex", + "rustc_version", + "rustversion", + "sysinfo", + "time", +] + +[[package]] +name = "version-compare" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "want" +version = "0.3.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.66", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] +name = "web-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa30049b1c872b72c89866d458eae9f20380ab280ffd1b1e18df2d3e2d98cfe0" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" +dependencies = [ + "ring 0.17.8", + "untrusted 0.9.0", +] + +[[package]] +name = "webpki-roots" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "weezl" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + 
"winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" +dependencies = [ + "windows-core", + "windows-targets 0.52.5", +] + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.5", +] + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.5", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +dependencies = [ + "windows_aarch64_gnullvm 0.52.5", + "windows_aarch64_msvc 0.52.5", + "windows_i686_gnu 0.52.5", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.5", + "windows_x86_64_gnu 0.52.5", + "windows_x86_64_gnullvm 0.52.5", + "windows_x86_64_msvc 0.52.5", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" + +[[package]] +name = "winnow" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3c52e9c97a68071b23e836c9380edae937f17b9c4667bd021973efc689f618d" +dependencies = [ + "memchr", +] + +[[package]] +name = "winreg" +version = "0.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" +dependencies = [ + "cfg-if", + "windows-sys 0.48.0", +] + +[[package]] +name = "zerocopy" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "byteorder", + "crc32fast", + "crossbeam-utils", + "flate2", +] + +[[package]] +name = "zune-core" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f423a2c17029964870cfaabb1f13dfab7d092a62a29a89264f4d36990ca414a" + +[[package]] +name = "zune-inflate" +version = "0.2.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73ab332fe2f6680068f3582b16a24f90ad7096d5d39b974d1c0aff0125116f02" +dependencies = [ + "simd-adler32", +] + +[[package]] +name = "zune-jpeg" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec866b44a2a1fd6133d363f073ca1b179f438f99e7e5bfb1e33f7181facfe448" +dependencies = [ + "zune-core", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..229fd677 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,28 @@ +[workspace] +members = [ + "benchmark", + "router", + "router/client", + "router/grpc-metadata", + "launcher" +] +resolver = "2" + +[workspace.package] +version = "2.0.5-dev0" +edition = "2021" +authors = ["Olivier Dehaene"] +homepage = "https://github.com/huggingface/text-generation-inference" + 
+[workspace.dependencies] +tokenizers = { version = "0.19.1", features = ["http"] } +hf-hub = { version = "0.3.1", features = ["tokio"] } +bitstream-io = { version = "2.3.0" } + +[profile.release] +debug = 1 +incremental = true +lto = "fat" +opt-level = 3 +codegen-units = 1 +panic = "abort" diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..904936d3 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,255 @@ +# Rust builder +FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef +WORKDIR /usr/src + +ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + +FROM chef as planner +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY --from=planner /usr/src/recipe.json recipe.json +RUN cargo chef cook --release --recipe-path recipe.json + +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo build --release + +# Python builder +# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile +FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as pytorch-install + +ARG PYTORCH_VERSION=2.3.0 +ARG PYTHON_VERSION=3.10 +# Keep in sync with `server/pyproject.toml +ARG CUDA_VERSION=12.1 +ARG MAMBA_VERSION=24.3.0-0 +ARG CUDA_CHANNEL=nvidia +ARG INSTALL_CHANNEL=pytorch +# Automatically set by buildx +ARG TARGETPLATFORM + +ENV PATH /opt/conda/bin:$PATH + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + build-essential \ + ca-certificates \ + ccache \ + curl \ + git && \ + rm -rf /var/lib/apt/lists/* + +# Install conda +# translating Docker's TARGETPLATFORM into mamba arches +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +# Install pytorch +# On arm64 we exit with an error code +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' 
-f 1-2)" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya + +# CUDA kernels builder image +FROM pytorch-install as kernel-builder + +ARG MAX_JOBS=8 + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + ninja-build cmake \ + && rm -rf /var/lib/apt/lists/* + +# Build Flash Attention CUDA kernels +FROM kernel-builder as flash-att-builder + +WORKDIR /usr/src + +COPY server/Makefile-flash-att Makefile + +# Build specific version of flash attention +RUN make build-flash-attention + +# Build Flash Attention v2 CUDA kernels +FROM kernel-builder as flash-att-v2-builder + +WORKDIR /usr/src + +COPY server/Makefile-flash-att-v2 Makefile + +# Build specific version of flash attention v2 +RUN make build-flash-attention-v2-cuda + +# Build Transformers exllama kernels +FROM kernel-builder as exllama-kernels-builder +WORKDIR /usr/src +COPY server/exllama_kernels/ . + +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build + +# Build Transformers exllamav2 kernels +FROM kernel-builder as exllamav2-kernels-builder +WORKDIR /usr/src +COPY server/exllamav2_kernels/ . + +# Build specific version of exllamav2 kernels +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build + +# Build Transformers awq kernels +FROM kernel-builder as awq-kernels-builder +WORKDIR /usr/src +COPY server/Makefile-awq Makefile +# Build specific version of awq kernels +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq + +# Build eetq kernels +FROM kernel-builder as eetq-kernels-builder +WORKDIR /usr/src +COPY server/Makefile-eetq Makefile +# Build specific version of eetq kernels +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq + +# Build Transformers CUDA kernels +FROM kernel-builder as custom-kernels-builder +WORKDIR /usr/src +COPY server/custom_kernels/ . 
+# Build specific version of transformers +RUN python setup.py build + +# Build vllm CUDA kernels +FROM kernel-builder as vllm-builder + +WORKDIR /usr/src + +ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" + +COPY server/Makefile-vllm Makefile + +# Build specific version of vllm +RUN make build-vllm-cuda + +# Build mamba kernels +FROM kernel-builder as mamba-builder +WORKDIR /usr/src +COPY server/Makefile-selective-scan Makefile +RUN make build-all + +# Text Generation Inference base image +FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base + +# Conda env +ENV PATH=/opt/conda/bin:$PATH \ + CONDA_PREFIX=/opt/conda + +# Text Generation Inference base env +ENV HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + +WORKDIR /usr/src + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + libssl-dev \ + ca-certificates \ + make \ + curl \ + git \ + && rm -rf /var/lib/apt/lists/* + +# Copy conda with PyTorch installed +COPY --from=pytorch-install /opt/conda /opt/conda + +# Copy build artifacts from flash attention builder +COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy build artifacts from flash attention v2 builder +COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy build artifacts from custom kernels builder +COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from exllama kernels builder +COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from exllamav2 kernels builder +COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from awq kernels builder +COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from eetq kernels builder +COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy builds artifacts from vllm builder +COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy build artifacts from mamba builder +COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages +COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages + +# Install flash-attention dependencies +RUN pip install einops --no-cache-dir + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install -r requirements_cuda.txt && \ + pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir + +# Install benchmarker +COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY 
--from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + build-essential \ + g++ \ + && rm -rf /var/lib/apt/lists/* + +# AWS Sagemaker compatible image +FROM base as sagemaker + +COPY sagemaker-entrypoint.sh entrypoint.sh +RUN chmod +x entrypoint.sh + +ENTRYPOINT ["./entrypoint.sh"] + +# Final image +FROM base + +COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh +RUN chmod +x /tgi-entrypoint.sh + +ENTRYPOINT ["/tgi-entrypoint.sh"] +CMD ["--json-output"] diff --git a/Dockerfile_amd b/Dockerfile_amd new file mode 100644 index 00000000..92dd0ea8 --- /dev/null +++ b/Dockerfile_amd @@ -0,0 +1,217 @@ +# Rust builder +FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef +WORKDIR /usr/src + +ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + +FROM chef as planner +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY --from=planner /usr/src/recipe.json recipe.json +RUN cargo chef cook --release --recipe-path recipe.json + +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo build --release + +# Text Generation Inference base image for RoCm +FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + build-essential \ + ca-certificates \ + ccache \ + curl \ + git \ + make \ + libssl-dev \ + g++ \ + # Needed to build VLLM & flash. + rocthrust-dev \ + hipsparse-dev \ + hipblas-dev \ + hipblaslt-dev \ + rocblas-dev \ + hiprand-dev \ + rocrand-dev \ + miopen-hip-dev \ + hipfft-dev \ + hipcub-dev \ + hipsolver-dev \ + rccl-dev \ + cmake \ + python3-dev && \ + rm -rf /var/lib/apt/lists/* + +# Keep in sync with `server/pyproject.toml +ARG MAMBA_VERSION=23.1.0-1 +ARG PYTORCH_VERSION='2.3.0' +ARG ROCM_VERSION='6.0.2' +ARG PYTHON_VERSION='3.10.10' +# Automatically set by buildx +ARG TARGETPLATFORM +ENV PATH /opt/conda/bin:$PATH + +# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. 
+# Install mamba +# translating Docker's TARGETPLATFORM into mamba arches +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + mamba init && \ + rm ~/mambaforge.sh + +# Install flash-attention, torch dependencies +RUN pip install numpy einops ninja --no-cache-dir + +RUN conda install intel::mkl-static intel::mkl-include +RUN pip uninstall -y triton && \ + git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ + cd triton/python && \ + pip install . + +RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir + +ARG _GLIBCXX_USE_CXX11_ABI="1" +ARG CMAKE_PREFIX_PATH="/opt/conda" +ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" +ARG BUILD_CAFFE2="0" \ + BUILD_CAFFE2_OPS="0" \ + USE_CUDA="0" \ + USE_ROCM="1" \ + BUILD_TEST="0" \ + USE_FBGEMM="0" \ + USE_NNPACK="0" \ + USE_QNNPACK="0" \ + USE_XNNPACK="0" \ + USE_FLASH_ATTENTION="1" \ + USE_MEM_EFF_ATTENTION="0" + +RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install + +# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm +ENV HIP_FORCE_DEV_KERNARG=1 + +# On MI250 and MI300, performance for flash attention with Triton FA is slightly better than with CK. +# However, Triton requires tuning for each prompt length, which is prohibitive. +ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 + +FROM base AS kernel-builder + +# Build vllm kernels +FROM kernel-builder AS vllm-builder +WORKDIR /usr/src + +COPY server/Makefile-vllm Makefile + +# Build specific version of vllm +RUN make build-vllm-rocm + +# Build Flash Attention v2 kernels +FROM kernel-builder AS flash-att-v2-builder +WORKDIR /usr/src + +COPY server/Makefile-flash-att-v2 Makefile + +# Build specific version of flash attention v2 +RUN make build-flash-attention-v2-rocm + +# Build Transformers CUDA kernels (gpt-neox and bloom) +FROM kernel-builder as custom-kernels-builder +WORKDIR /usr/src +COPY server/custom_kernels/ . +RUN python setup.py build + +# Build exllama kernels +FROM kernel-builder as exllama-kernels-builder +WORKDIR /usr/src +COPY server/exllama_kernels/ . + +RUN python setup.py build + +# Build exllama v2 kernels +FROM kernel-builder as exllamav2-kernels-builder +WORKDIR /usr/src +COPY server/exllamav2_kernels/ . 
+ +RUN python setup.py build + +FROM base as base-copy + +# Text Generation Inference base env +ENV HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + +# Copy builds artifacts from vllm builder +COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy build artifacts from flash attention v2 builder +COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy build artifacts from custom kernels builder +COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy build artifacts from exllama kernels builder +COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy build artifacts from exllamav2 kernels builder +COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install -r requirements_rocm.txt && \ + pip install ".[accelerate, peft, outlines]" --no-cache-dir + +# Install benchmarker +COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher + +# AWS Sagemaker compatible image +FROM base as sagemaker + +COPY sagemaker-entrypoint.sh entrypoint.sh +RUN chmod +x entrypoint.sh + +ENTRYPOINT ["./entrypoint.sh"] + +# Final image +FROM base-copy + +COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh +RUN chmod +x /tgi-entrypoint.sh + +ENTRYPOINT ["/tgi-entrypoint.sh"] +CMD ["--json-output"] diff --git a/Dockerfile_intel b/Dockerfile_intel new file mode 100644 index 00000000..809992e1 --- /dev/null +++ b/Dockerfile_intel @@ -0,0 +1,91 @@ +FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef +WORKDIR /usr/src + +ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + +FROM chef as planner +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY --from=planner /usr/src/recipe.json recipe.json +RUN cargo chef cook --release --recipe-path recipe.json + +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo build --release + + +# Text Generation Inference base image for Intel +FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base + +USER root +# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it +RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \ + dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb + +RUN wget -qO - 
https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null + +RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ +| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list + +RUN apt-get update && apt install -y intel-basekit xpu-smi + +# Text Generation Inference base env +ENV HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + + +WORKDIR /usr/src +RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl +RUN pip install intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install -r requirements_cuda.txt && \ + pip install ".[accelerate, peft, outlines]" --no-cache-dir + +ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest +ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest +ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric +ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib +ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64: +ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV CCL_ZE_IPC_EXCHANGE=sockets + +# Install benchmarker +COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher + +# Final image +FROM base + +ENTRYPOINT ["text-generation-launcher"] +CMD ["--json-output"] diff --git a/Dockerfile_kvrun b/Dockerfile_kvrun new file mode 100644 index 00000000..9d2b6934 --- /dev/null +++ b/Dockerfile_kvrun @@ -0,0 +1,265 @@ +# Rust builder +FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +WORKDIR /usr/src + +ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + +FROM chef as planner +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY --from=planner /usr/src/recipe.json recipe.json +RUN cargo chef cook --release 
--recipe-path recipe.json + +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo build --release + +# Python builder +# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile +FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as pytorch-install + +ARG PYTORCH_VERSION=2.3.0 +ARG PYTHON_VERSION=3.10 +# Keep in sync with `server/pyproject.toml +ARG CUDA_VERSION=12.1 +ARG FLASHINFER_URL=https://flashinfer.ai/whl/cu121/torch2.3 +ARG MAMBA_VERSION=24.3.0-0 +ARG CUDA_CHANNEL=nvidia +ARG INSTALL_CHANNEL=pytorch +# Automatically set by buildx +ARG TARGETPLATFORM + +ENV PATH /opt/conda/bin:$PATH + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + build-essential \ + ca-certificates \ + ccache \ + curl \ + g++ \ + git && \ + rm -rf /var/lib/apt/lists/* + +# Install conda +# translating Docker's TARGETPLATFORM into mamba arches +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +# Install pytorch +# On arm64 we exit with an error code +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya + +# CUDA kernels builder image +FROM pytorch-install as kernel-builder + +ARG MAX_JOBS=8 + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + ninja-build cmake \ + && rm -rf /var/lib/apt/lists/* + +# Punica kernel builder +COPY server/punica_kernels punica_kernels +RUN pip install wheel setuptools --upgrade && \ + cd punica_kernels && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" pip install -v --no-build-isolation . +RUN pip install flashinfer -i $FLASHINFER_URL +RUN git clone https://github.com/PanQiWei/AutoGPTQ.git && cd AutoGPTQ && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" pip install -vvv --no-build-isolation -e . +RUN git clone https://github.com/casper-hansen/AutoAWQ && cd AutoAWQ && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" pip install -e . + +# Build Flash Attention CUDA kernels +FROM kernel-builder as flash-att-builder + +WORKDIR /usr/src + +COPY server/Makefile-flash-att Makefile + +# Build specific version of flash attention +RUN make build-flash-attention + +# Build Flash Attention v2 CUDA kernels +FROM kernel-builder as flash-att-v2-builder + +WORKDIR /usr/src + +COPY server/Makefile-flash-att-v2 Makefile + +# Build specific version of flash attention v2 +RUN make build-flash-attention-v2-cuda + +# Build Transformers exllama kernels +FROM kernel-builder as exllama-kernels-builder +WORKDIR /usr/src +COPY server/exllama_kernels/ . + +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build + +# Build Transformers exllama kernels +FROM kernel-builder as exllamav2-kernels-builder +WORKDIR /usr/src +COPY server/exllamav2_kernels/ . 
+ +# Build specific version of transformers +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build + +# Build Transformers awq kernels +FROM kernel-builder as awq-kernels-builder +WORKDIR /usr/src +COPY server/Makefile-awq Makefile +# Build specific version of transformers +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq + +# Build eetq kernels +FROM kernel-builder as eetq-kernels-builder +WORKDIR /usr/src +COPY server/Makefile-eetq Makefile +# Build specific version of transformers +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq + +# Build Transformers CUDA kernels +FROM kernel-builder as custom-kernels-builder +WORKDIR /usr/src +COPY server/custom_kernels/ . +# Build specific version of transformers +RUN python setup.py build + +# Build vllm CUDA kernels +FROM kernel-builder as vllm-builder + +WORKDIR /usr/src + +ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0+PTX" + +COPY server/Makefile-vllm Makefile + +# Build specific version of vllm +RUN make build-vllm-cuda + +# Build mamba kernels +FROM kernel-builder as mamba-builder +WORKDIR /usr/src +COPY server/Makefile-selective-scan Makefile +RUN make build-all + +# Text Generation Inference base image +FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base + +# Conda env +ENV PATH=/opt/conda/bin:$PATH \ + CONDA_PREFIX=/opt/conda + +# Text Generation Inference base env +ENV HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + +WORKDIR /usr/src + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + libssl-dev \ + ca-certificates \ + make \ + curl \ + git \ + && rm -rf /var/lib/apt/lists/* + +# Copy conda with PyTorch installed +COPY --from=pytorch-install /opt/conda /opt/conda + +# Copy build artifacts from flash attention builder +COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy build artifacts from flash attention v2 builder +COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy build artifacts from custom kernels builder +COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from exllama kernels builder +COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from exllamav2 kernels builder +COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from awq kernels builder +COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from eetq kernels builder +COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy builds artifacts from vllm builder +COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages + +# Copy build artifacts from mamba builder +COPY --from=mamba-builder 
/usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages +COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages + +# Install flash-attention dependencies +RUN pip install einops --no-cache-dir + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + build-essential \ + g++ \ + && rm -rf /var/lib/apt/lists/* + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install -r requirements_cuda.txt && \ + pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir + +# Install benchmarker +COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher + +# AWS Sagemaker compatible image +FROM base as sagemaker + +COPY sagemaker-entrypoint.sh entrypoint.sh +RUN chmod +x entrypoint.sh + +ENTRYPOINT ["./entrypoint.sh"] + +# Final image +FROM base + +COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh +RUN chmod +x /tgi-entrypoint.sh + +ENTRYPOINT ["/tgi-entrypoint.sh"] +CMD ["--json-output"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..7d0e8034 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2022 Hugging Face + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile index 79d20f0b..ccf99e0a 100644 --- a/Makefile +++ b/Makefile @@ -1,50 +1,56 @@ -codebase: - rm -rf build - mkdir build - cp -r third_party/text-generation-inference/benchmark build/ - cp -r third_party/text-generation-inference/clients build/ - cp -r third_party/text-generation-inference/integration-tests build/ - cp -r third_party/text-generation-inference/launcher build/ - cp -r third_party/text-generation-inference/load_tests build/ - cp -r third_party/text-generation-inference/proto build/ - cp -r third_party/text-generation-inference/router build/ - cp -r third_party/text-generation-inference/server build/ - cp third_party/text-generation-inference/Cargo*.* build/ - cp -r server build/ - cp -r proto build/ - cp -r router build/ - cp -r launcher build/ - cd build/server && make gen-server - -install-server: - cd build/server && make install - cd build/server/punica_kernels && pip install -v --no-build-isolation . 
- -install-custom-kernels: - if [ "$$BUILD_EXTENSIONS" = "True" ]; then cd build/server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need to set the BUILD_EXTENSIONS environment variable to 'True' in order to build them. (Please read the docs, kernels might not work on all hardware)"; fi - -install-router: - cd build/router && cargo install --path . - -install-launcher: - cd build/launcher && cargo install --path . - -install-benchmark: - cd build/benchmark && cargo install --path . - -install: codebase install-server install-router install-launcher install-custom-kernels - -rust-tests: install-router install-launcher - cargo test - -integration-tests: install-integration-tests - pytest -s -vv -m "not private" integration-tests - -update-integration-tests: install-integration-tests - pytest -s -vv --snapshot-update integration-tests - -python-server-tests: - HF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m "not private" build/server/tests - -python-client-tests: - pytest build/clients/python/tests \ No newline at end of file +install-punica-kernel: + pip install wheel setuptools --upgrade + cd server/punica_kernels && pip install -v --no-build-isolation . + +install-server: + cd server && make install + +install-custom-kernels: + if [ "$$BUILD_EXTENSIONS" = "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need to set the BUILD_EXTENSIONS environment variable to 'True' in order to build them. (Please read the docs, kernels might not work on all hardware)"; fi + +install-integration-tests: + cd integration-tests && pip install -r requirements.txt + cd clients/python && pip install . + +install-router: + cd router && cargo install --path . + +install-launcher: + cd launcher && cargo install --path . + +install-benchmark: + cd benchmark && cargo install --path . 
+ +install: install-server install-router install-launcher install-custom-kernels install-punica-kernel + +server-dev: + cd server && make run-dev + +router-dev: + cd router && cargo run -- --port 8080 + +rust-tests: install-router install-launcher + cargo test + +integration-tests: install-integration-tests + pytest -s -vv -m "not private" integration-tests + +update-integration-tests: install-integration-tests + pytest -s -vv --snapshot-update integration-tests + +python-server-tests: + HF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m "not private" server/tests + +python-client-tests: + pytest clients/python/tests + +python-tests: python-server-tests python-client-tests + +run-falcon-7b-instruct: + text-generation-launcher --model-id tiiuae/falcon-7b-instruct --port 8080 + +run-falcon-7b-instruct-quantize: + text-generation-launcher --model-id tiiuae/falcon-7b-instruct --quantize bitsandbytes --port 8080 + +clean: + rm -rf target aml diff --git a/README.md b/README.md index f9a44fd4..e60c6e6b 100644 --- a/README.md +++ b/README.md @@ -1,117 +1,122 @@ -# kv.run -(Limited) comparison of popular model serving solutions - -| Solution | Inference backend | Serving backend | Advanced kernel support | Model support | -|-----------------|-------------------|----------------------|--------------------------------------------------------------------------------------------------|----------------------------| -| Huggingface TGI | Pytorch | HF TGI (Rust) | Paged + Flash attention | Language | -| Deepspeed MII | PyTorch | Deepspeed (Python) | [DeepSpeed-Kernels](https://github.com/microsoft/DeepSpeed-Kernels) | Language | -| TensorRT-LLM | TensorRT-LLM | TensorRT-LLM (C++) | [TensorRT XQA](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/XQA-kernel.md) | Language | -| vLLM | vLLM | vLLM (Python) | Paged + Flash attention | Language | -| kv.run | PyTorch | HF TGI + more (Rust) | Paged + Flash attention, [FlashInfer](https://github.com/flashinfer-ai/flashinfer) | Language, diffusion (exp.) | - - - -## Installation -#### Sync submodules -```shell -git submodule sync -git submodule update --init -``` - -#### Install Rust -[Script](https://rustup.rs/): -```shell -curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -``` -#### Install Protobuf -```shell -sudo apt-get install libssl-dev gcc -y -PROTOC_ZIP=protoc-21.12-linux-x86_64.zip -curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP -sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc -sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*' -rm -f $PROTOC_ZIP -``` - -#### Build Code Base -```shell -make install -``` - -#### Install Kernel Libraries -```shell -# Install FlashInfer -# For CUDA 12.1 & torch 2.3 -pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3 -# For other CUDA & torch versions, please check https://docs.flashinfer.ai/installation.html - -# Install Flash and Paged Attention -cd build/server -make install-flash-attention & make install-vllm-cuda -``` -You can debug/edit code in the build folder. When done, use python copy_back.py to copy changes back to the original src folder. 
- -## Usages -#### Local API tests -```shell -cd build/server/examples & python test_local_api.py -``` -#### Local UI demo -(Inherited from [Punica](https://github.com/punica-ai/punica)) - -[demo.mp4](https://github.com/mlsys-io/kv.run/assets/12567967/977b09fb-bd90-4757-85ab-e5fc2a58cd93) - -#### Deploy services -```shell -text-generation-launcher --model-id tjluyao/llama-3-8b --lora-ids tjluyao/llama-3-8b-math;tjluyao/llama-3-8b-zh -``` -#### Using quantized models -Add --quantize [Method] to the command above, for example: -```shell -text-generation-launcher --model-id TechxGenus/gemma-2b-GPTQ --lora-ids tjluyao/gemma-2b-it-math --quantize gptq -``` -The supported quantization methods include: -- AWQ: 4-bit. Need specific quantized model. -- EETQ: 8-bit. Can work for any model. -- GPTQ: 4-bit. Need specific quantized model. -- Bitandbytes: 8-bit. Can work for any model. - -For AWQ and EETQ quantization, you need to build their specific kernels: -```shell -# AWQ -cd build/server & make install-awq -# EETQ -cd build/server & make install-eetq -``` - -## Model support matrix -Note: L = Language, I = Image - -| Model | MOE | Size | Modality | Quantization | Tensor Parallelism | FlashInfer | Multi-LoRA | -|------------------------------------------------------------------------------|------|-------|----------|--------------|--------------------|------------|------------| -| [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) | | 9B | L, I ⇒ L | | | | | -| [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) | | 8B | L, I ⇒ L | | | | | -| [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) | | 13B | L, I ⇒ L | | | | | -| [Llama 2](https://huggingface.co/meta-llama/Llama-2-7b-hf) | | 7B | L ⇒ L | | | ✔ | ✔ | -| [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | | 8B | L ⇒ L | | | ✔ | ✔ | -| [Phi 1.5](https://huggingface.co/microsoft/phi-1_5) | | 1.3B | L ⇒ L | | | | | -| [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) | | 3.8B | L ⇒ L | | | | | -| [Gemma](https://huggingface.co/google/gemma-2b) | | 2B | L ⇒ L | | | ✔ | ✔ | -| [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) | | 104B | L ⇒ L | | | | | -| [Dbrx](https://huggingface.co/databricks/dbrx-instruct) | ✔ | 132B | L ⇒ L | | | | | -| [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj) | | 2.8B | L ⇒ L | | | | | -| [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) | | 7B | L ⇒ L | | | | | -| [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1) | ✔ | 8x22B | L ⇒ L | | | | | -| [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder) | | 1.1B | L ⇒ L | | | | | -| [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) | | 7B | L ⇒ L | | | | | -| [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct) | | 7B | L ⇒ L | | ✔ | | | -| [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) | | 15B | L ⇒ L | | | | | -| [Qwen 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) | | 15B | L ⇒ L | | | | | -| [Opt](https://huggingface.co/facebook/opt-6.7b) | | 6.7B | L ⇒ L | | | | | -| [T5](https://huggingface.co/google-t5/t5-11b) | | 11B | L ⇒ L | | | | | -| [Galactica](https://huggingface.co/facebook/galactica-120b) | | 120B | L ⇒ L | | | | | -| [SantaCoder](https://huggingface.co/bigcode/santacoder) | | 1.1B | L ⇒ L | | | | | -| [Bloom](https://huggingface.co/bigscience/bloom-560m) | | 560M | L ⇒ L | | | | | -| 
[Mpt](https://huggingface.co/mosaicml/mpt-7b-instruct) | | 7B | L ⇒ L | | | | |
-| [Gpt2](https://huggingface.co/openai-community/gpt2) | | 124M | L ⇒ L | | | | |
-| [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b) | | 20B | L ⇒ L | | ✔ | | |
+# kv.run
+(Limited) comparison of popular model serving solutions
+
+| Solution | Inference backend | Serving backend | Advanced kernel support | Model support |
+|-----------------|-------------------|----------------------|--------------------------------------------------------------------------------------------------|----------------------------|
+| Huggingface TGI | PyTorch | HF TGI (Rust) | Paged + Flash attention | Language |
+| Deepspeed MII | PyTorch | Deepspeed (Python) | [DeepSpeed-Kernels](https://github.com/microsoft/DeepSpeed-Kernels) | Language |
+| TensorRT-LLM | TensorRT-LLM | TensorRT-LLM (C++) | [TensorRT XQA](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/XQA-kernel.md) | Language |
+| vLLM | vLLM | vLLM (Python) | Paged + Flash attention | Language |
+| kv.run | PyTorch | HF TGI + more (Rust) | Paged + Flash attention, [FlashInfer](https://github.com/flashinfer-ai/flashinfer) | Language, diffusion (exp.) |
+
+
+
+## Installation
+#### Sync submodules
+```shell
+git submodule sync
+git submodule update --init
+```
+
+#### Install Rust
+[Script](https://rustup.rs/):
+```shell
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+```
+#### Install Protobuf
+```shell
+sudo apt-get install libssl-dev gcc -y
+PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
+curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
+sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc
+sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
+rm -f $PROTOC_ZIP
+```
+
+#### Build Code Base
+```shell
+make install
+```
+
+#### Install Kernel Libraries
+```shell
+# Install FlashInfer
+# For CUDA 12.1 & torch 2.3
+pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3
+# For other CUDA & torch versions, please check https://docs.flashinfer.ai/installation.html
+
+# Install Flash and Paged Attention
+cd server
+make install-flash-attention && make install-vllm-cuda
+```
+You can debug/edit code in the build folder. When done, use `python copy_back.py` to copy changes back to the original src folder.
+
+## Usage
+#### Local API tests
+```shell
+cd server/examples && python test_local_api.py
+```
+#### Local UI demo
+(Inherited from [Punica](https://github.com/punica-ai/punica))
+
+[demo.mp4](https://github.com/mlsys-io/kv.run/assets/12567967/977b09fb-bd90-4757-85ab-e5fc2a58cd93)
+
+#### Deploy services
+```shell
+text-generation-launcher --model-id tjluyao/llama-3-8b --lora-ids tjluyao/llama-3-8b-math;tjluyao/llama-3-8b-zh
+```
+#### Using quantized models
+Add `--quantize [Method]` to the command above, for example:
+```shell
+text-generation-launcher --model-id TechxGenus/gemma-2b-GPTQ --lora-ids tjluyao/gemma-2b-it-math --quantize gptq
+```
+The supported quantization methods include:
+- AWQ: 4-bit. Requires an AWQ-quantized model.
+- EETQ: 8-bit. Works with any model.
+- GPTQ: 4-bit. Requires a GPTQ-quantized model.
+- Bitsandbytes: 8-bit. Works with any model.
+
+For AWQ, EETQ, and GPTQ quantization, you need to build or install their specific kernels:
+```shell
+# AWQ
+cd server && make install-awq
+git clone https://github.com/casper-hansen/AutoAWQ && cd AutoAWQ
+pip install -e . 
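+# Optional sanity check: confirm the AutoAWQ build imports (assumes the package exposes the `awq` Python module)
+python -c "import awq"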
+# EETQ +cd server && make install-eetq +# GTPQ +git clone https://github.com/PanQiWei/AutoGPTQ.git && cd AutoGPTQ +pip install -vvv --no-build-isolation -e . +``` + +## Model support matrix +Note: L = Language, I = Image + +| Model | MOE | Size | Modality | Quantization | Tensor Parallelism | FlashInfer | Multi-LoRA | +|------------------------------------------------------------------------------|------|-------|----------|--------------|--------------------|------------|------------| +| [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) | | 9B | L, I ⇒ L | | | | | +| [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) | | 8B | L, I ⇒ L | | | | | +| [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) | | 13B | L, I ⇒ L | | | | | +| [Llama 2](https://huggingface.co/meta-llama/Llama-2-7b-hf) | | 7B | L ⇒ L | | | ✔ | ✔ | +| [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | | 8B | L ⇒ L | | | ✔ | ✔ | +| [Phi 1.5](https://huggingface.co/microsoft/phi-1_5) | | 1.3B | L ⇒ L | | | | | +| [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) | | 3.8B | L ⇒ L | | | | | +| [Gemma](https://huggingface.co/google/gemma-2b) | | 2B | L ⇒ L | | | ✔ | ✔ | +| [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) | | 104B | L ⇒ L | | | | | +| [Dbrx](https://huggingface.co/databricks/dbrx-instruct) | ✔ | 132B | L ⇒ L | | | | | +| [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj) | | 2.8B | L ⇒ L | | | | | +| [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) | | 7B | L ⇒ L | | | | | +| [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1) | ✔ | 8x22B | L ⇒ L | | | | | +| [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder) | | 1.1B | L ⇒ L | | | | | +| [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) | | 7B | L ⇒ L | | | | | +| [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct) | | 7B | L ⇒ L | | ✔ | | | +| [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) | | 15B | L ⇒ L | | | | | +| [Qwen 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) | | 15B | L ⇒ L | | | | | +| [Opt](https://huggingface.co/facebook/opt-6.7b) | | 6.7B | L ⇒ L | | | | | +| [T5](https://huggingface.co/google-t5/t5-11b) | | 11B | L ⇒ L | | | | | +| [Galactica](https://huggingface.co/facebook/galactica-120b) | | 120B | L ⇒ L | | | | | +| [SantaCoder](https://huggingface.co/bigcode/santacoder) | | 1.1B | L ⇒ L | | | | | +| [Bloom](https://huggingface.co/bigscience/bloom-560m) | | 560M | L ⇒ L | | | | | +| [Mpt](https://huggingface.co/mosaicml/mpt-7b-instruct) | | 7B | L ⇒ L | | | | | +| [Gpt2](https://huggingface.co/openai-community/gpt2) | | 124M | L ⇒ L | | | | | +| [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b) | | 20B | L ⇒ L | | ✔ | | | diff --git a/assets/architecture.png b/assets/architecture.png new file mode 100644 index 00000000..1bcd1283 Binary files /dev/null and b/assets/architecture.png differ diff --git a/assets/benchmark.png b/assets/benchmark.png new file mode 100644 index 00000000..64d538a0 Binary files /dev/null and b/assets/benchmark.png differ diff --git a/assets/tgi_grafana.json b/assets/tgi_grafana.json new file mode 100644 index 00000000..5f5a74ad --- /dev/null +++ b/assets/tgi_grafana.json @@ -0,0 +1,3999 @@ +{ + "__inputs": [ + { + "name": "DS_PROMETHEUS_EKS API INFERENCE PROD", + "label": "Prometheus EKS API Inference Prod", + "description": "", + "type": "datasource", + "pluginId": 
"prometheus", + "pluginName": "Prometheus" + } + ], + "__elements": {}, + "__requires": [ + { + "type": "panel", + "id": "gauge", + "name": "Gauge", + "version": "" + }, + { + "type": "grafana", + "id": "grafana", + "name": "Grafana", + "version": "10.0.2" + }, + { + "type": "panel", + "id": "heatmap", + "name": "Heatmap", + "version": "" + }, + { + "type": "datasource", + "id": "prometheus", + "name": "Prometheus", + "version": "1.0.0" + }, + { + "type": "panel", + "id": "timeseries", + "name": "Time series", + "version": "" + } + ], + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "target": { + "limit": 100, + "matchAny": false, + "tags": [], + "type": "dashboard" + }, + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 2, + "id": 551, + "links": [], + "liveNow": false, + "panels": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "fieldMinMax": false, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 1000 + } + ] + }, + "unit": "ms" + }, + "overrides": [] + }, + "gridPos": { + "h": 7, + "w": 8, + "x": 0, + "y": 0 + }, + "id": 49, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "mean" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "(histogram_quantile(0.5, sum by (le) (rate(tgi_request_queue_duration_bucket{container=\"$service\"}[10m]))) * 1000) > 0", + "hide": true, + "instant": false, + "legendFormat": "__auto", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "(histogram_quantile(0.5, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"prefill\", container=\"$service\"}[10m]))) * 1000) > 0", + "hide": true, + "instant": false, + "legendFormat": "__auto", + "range": true, + "refId": "C" + }, + { + "datasource": { + "name": "Expression", + "type": "__expr__", + "uid": "__expr__" + }, + "expression": "$B + $C", + "hide": false, + "refId": "D", + "type": "math" + } + ], + "title": "Time to first token", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "ms" + }, + "overrides": [] + }, + "gridPos": { + "h": 7, + "w": 8, + "x": 9, + "y": 0 + }, + "id": 44, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "mean" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + 
"wideLayout": true + }, + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "(histogram_quantile(0.5, sum by (le) (rate(tgi_batch_forward_duration_bucket{method=\"decode\", container=\"$service\"}[10m]))) * 1000)>0", + "instant": false, + "range": true, + "refId": "A" + } + ], + "title": "Decode per-token latency", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 7, + "w": 7, + "x": 17, + "y": 0 + }, + "id": 45, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "mean" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "sum((rate(tgi_request_generated_tokens_sum{container=\"$service\"}[10m]) / rate(tgi_request_generated_tokens_count{container=\"$service\"}[10m]))>0)", + "instant": false, + "range": true, + "refId": "A" + } + ], + "title": "Throughput (generated tok/s)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "none" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 7 + }, + "id": 48, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": 
"code", + "expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_input_length_bucket{container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_input_length_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_input_length_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Number of tokens per prompt", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "none" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 7 + }, + "id": 30, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_generated_tokens_bucket{container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_generated_tokens_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) 
(rate(tgi_request_generated_tokens_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Number of generated tokens per request", + "type": "timeseries" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 15 + }, + "id": 20, + "panels": [], + "title": "General", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 30, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 6, + "x": 0, + "y": 16 + }, + "id": 4, + "maxDataPoints": 100, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "sum(increase(tgi_request_success{container=\"$service\"}[1m]))", + "legendFormat": "Success", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "sum(increase(tgi_request_failure{container=\"$service\"}[1m])) by (err)", + "hide": false, + "legendFormat": "Error: {{err}}", + "range": true, + "refId": "B" + } + ], + "title": "Requests", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": 
"color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 13, + "w": 9, + "x": 6, + "y": 16 + }, + "id": 6, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_mean_time_per_token_duration_bucket{container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_mean_time_per_token_duration_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_mean_time_per_token_duration_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Mean Time Per Token quantiles", + "type": "timeseries" + }, + { + "cards": {}, + "color": { + "cardColor": "#5794F2", + "colorScale": "linear", + "colorScheme": "interpolateSpectral", + "exponent": 0.5, + "min": 0, + "mode": "opacity" + }, + "dataFormat": "tsbuckets", + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, + "gridPos": { + "h": 13, + "w": 9, + "x": 15, + "y": 16 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 13, + "legend": { + "show": false + }, + "maxDataPoints": 25, + "options": { + "calculate": false, + "calculation": {}, + "cellGap": 2, + "cellValues": {}, + "color": { + "exponent": 0.5, + "fill": "#5794F2", + "min": 0, + "mode": "scheme", + "reverse": false, + "scale": "exponential", + "scheme": "Spectral", + "steps": 128 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": false + }, + "rowsFrame": { + "layout": "auto" + }, + "showValue": "never", + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": false + }, + "yAxis": { + "axisPlacement": "left", + "decimals": 1, + "reverse": false, + "unit": "s" + } + }, + "pluginVersion": "10.4.2", + "reverseYBuckets": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "exemplar": true, + "expr": "sum(increase(tgi_request_mean_time_per_token_duration_bucket{container=\"$service\"}[5m])) by (le)", + "format": "heatmap", + "interval": "", + "legendFormat": "{{ le }}", + "range": true, + "refId": "A" + } + ], + "title": "Mean Time Per Token", + "tooltip": { + "show": true, + "showHistogram": false + }, + "type": "heatmap", + 
"xAxis": { + "show": true + }, + "yAxis": { + "decimals": 1, + "format": "s", + "logBase": 1, + "show": true + }, + "yBucketBound": "auto" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "percentage", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "orange", + "value": 70 + }, + { + "color": "red", + "value": 85 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 5, + "w": 3, + "x": 0, + "y": 24 + }, + "id": 18, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": false + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "9.1.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "count(tgi_request_count{container=\"$service\"})", + "legendFormat": "Replicas", + "range": true, + "refId": "A" + } + ], + "title": "Number of replicas", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "mappings": [], + "thresholds": { + "mode": "percentage", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "orange", + "value": 70 + }, + { + "color": "red", + "value": 85 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 5, + "w": 3, + "x": 3, + "y": 24 + }, + "id": 32, + "options": { + "minVizHeight": 75, + "minVizWidth": 75, + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showThresholdLabels": false, + "showThresholdMarkers": true, + "sizing": "auto" + }, + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "sum(tgi_queue_size{container=\"$service\"})", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Queue Size", + "type": "gauge" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 29 + }, + "id": 26, + "panels": [], + "title": "Batching", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "bars", + "fillOpacity": 50, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + 
"lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "normal" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 5, + "w": 6, + "x": 0, + "y": 30 + }, + "id": 29, + "maxDataPoints": 40, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": false + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "9.1.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "avg(tgi_batch_current_max_tokens{container=\"$service\"})", + "legendFormat": "{{ pod }}", + "range": true, + "refId": "A" + } + ], + "title": "Max tokens per batch", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "none" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 9, + "w": 4, + "x": 6, + "y": 30 + }, + "id": 33, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_skipped_tokens_bucket{container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_skipped_tokens_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": 
"prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_skipped_tokens_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Speculated Tokens", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "none" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 9, + "w": 5, + "x": 10, + "y": 30 + }, + "id": 46, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_input_length_bucket{container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_input_length_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_input_length_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Prompt Tokens", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + 
"fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 9, + "w": 9, + "x": 15, + "y": 30 + }, + "id": 8, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_duration_bucket{container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_duration_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_duration_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Latency quantiles", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "bars", + "fillOpacity": 50, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "normal" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 0, + "y": 35 + }, + "id": 27, + "maxDataPoints": 40, + "options": { + "legend": { + "calcs": [], + 
"displayMode": "list", + "placement": "bottom", + "showLegend": false + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "9.1.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "avg(tgi_batch_current_size{container=\"$service\"})", + "legendFormat": "{{ pod }}", + "range": true, + "refId": "A" + } + ], + "title": "Batch Size", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 30, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 9, + "w": 6, + "x": 0, + "y": 39 + }, + "id": 28, + "maxDataPoints": 100, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "sum(increase(tgi_batch_concat{container=\"$service\"}[1m])) by (reason)", + "hide": false, + "legendFormat": "Reason: {{ reason }}", + "range": true, + "refId": "B" + } + ], + "title": "Concatenates", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": 
[ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 9, + "w": 9, + "x": 6, + "y": 39 + }, + "id": 31, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_queue_duration_bucket{container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_queue_duration_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_queue_duration_bucket{container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Queue quantiles", + "type": "timeseries" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 48 + }, + "id": 22, + "panels": [], + "title": "Prefill", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 11, + "w": 12, + "x": 0, + "y": 49 + }, + "id": 7, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by (le) 
(rate(tgi_batch_inference_duration_bucket{method=\"prefill\", container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"prefill\", container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"prefill\", container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Prefill Quantiles", + "type": "timeseries" + }, + { + "cards": {}, + "color": { + "cardColor": "#5794F2", + "colorScale": "linear", + "colorScheme": "interpolateSpectral", + "exponent": 0.5, + "min": 0, + "mode": "opacity" + }, + "dataFormat": "tsbuckets", + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, + "gridPos": { + "h": 11, + "w": 12, + "x": 12, + "y": 49 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 14, + "legend": { + "show": false + }, + "maxDataPoints": 25, + "options": { + "calculate": false, + "calculation": {}, + "cellGap": 2, + "cellValues": {}, + "color": { + "exponent": 0.5, + "fill": "#5794F2", + "min": 0, + "mode": "scheme", + "reverse": false, + "scale": "exponential", + "scheme": "Spectral", + "steps": 128 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": false + }, + "rowsFrame": { + "layout": "auto" + }, + "showValue": "never", + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": false + }, + "yAxis": { + "axisPlacement": "left", + "decimals": 1, + "reverse": false, + "unit": "s" + } + }, + "pluginVersion": "10.4.2", + "reverseYBuckets": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "exemplar": true, + "expr": "sum(increase(tgi_batch_inference_duration_bucket{method=\"prefill\", container=\"$service\"}[5m])) by (le)", + "format": "heatmap", + "interval": "", + "legendFormat": "{{ le }}", + "range": true, + "refId": "A" + } + ], + "title": "Prefill Latency", + "tooltip": { + "show": true, + "showHistogram": false + }, + "type": "heatmap", + "xAxis": { + "show": true + }, + "yAxis": { + "decimals": 1, + "format": "s", + "logBase": 1, + "show": true + }, + "yBucketBound": "auto" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 60 + }, + "id": 24, + "panels": [], + "title": "Decode", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + 
"gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 11, + "w": 12, + "x": 0, + "y": 61 + }, + "id": 11, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Decode quantiles", + "type": "timeseries" + }, + { + "cards": {}, + "color": { + "cardColor": "#5794F2", + "colorScale": "linear", + "colorScheme": "interpolateSpectral", + "exponent": 0.5, + "min": 0, + "mode": "opacity" + }, + "dataFormat": "tsbuckets", + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, + "gridPos": { + "h": 11, + "w": 12, + "x": 12, + "y": 61 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 15, + "legend": { + "show": false + }, + "maxDataPoints": 25, + "options": { + "calculate": false, + "calculation": {}, + "cellGap": 2, + "cellValues": {}, + "color": { + "exponent": 0.5, + "fill": "#5794F2", + "min": 0, + "mode": "scheme", + "reverse": false, + "scale": "exponential", + "scheme": "Spectral", + "steps": 128 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { 
+ "le": 1e-9 + }, + "legend": { + "show": false + }, + "rowsFrame": { + "layout": "auto" + }, + "showValue": "never", + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": false + }, + "yAxis": { + "axisPlacement": "left", + "decimals": 1, + "reverse": false, + "unit": "s" + } + }, + "pluginVersion": "10.4.2", + "reverseYBuckets": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "exemplar": true, + "expr": "sum(increase(tgi_batch_inference_duration_bucket{method=\"decode\", container=\"$service\"}[5m])) by (le)", + "format": "heatmap", + "interval": "", + "legendFormat": "{{ le }}", + "range": true, + "refId": "A" + } + ], + "title": "Decode Latency", + "tooltip": { + "show": true, + "showHistogram": false + }, + "type": "heatmap", + "xAxis": { + "show": true + }, + "yAxis": { + "decimals": 1, + "format": "s", + "logBase": 1, + "show": true + }, + "yBucketBound": "auto" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 72 + }, + "id": 43, + "panels": [], + "title": "Debug", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 11, + "w": 6, + "x": 0, + "y": 73 + }, + "id": 38, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_batch_forward_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_batch_forward_duration_bucket{method=\"decode\", 
container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_batch_forward_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Forward quantiles", + "type": "timeseries" + }, + { + "cards": {}, + "color": { + "cardColor": "#5794F2", + "colorScale": "linear", + "colorScheme": "interpolateSpectral", + "exponent": 0.5, + "min": 0, + "mode": "opacity" + }, + "dataFormat": "tsbuckets", + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, + "gridPos": { + "h": 11, + "w": 6, + "x": 6, + "y": 73 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 35, + "legend": { + "show": false + }, + "maxDataPoints": 25, + "options": { + "calculate": false, + "calculation": {}, + "cellGap": 2, + "cellValues": {}, + "color": { + "exponent": 0.5, + "fill": "#5794F2", + "min": 0, + "mode": "scheme", + "reverse": false, + "scale": "exponential", + "scheme": "Spectral", + "steps": 128 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": false + }, + "rowsFrame": { + "layout": "auto" + }, + "showValue": "never", + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": false + }, + "yAxis": { + "axisPlacement": "left", + "decimals": 1, + "reverse": false, + "unit": "s" + } + }, + "pluginVersion": "10.4.2", + "reverseYBuckets": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "exemplar": true, + "expr": "sum(increase(tgi_batch_forward_duration_bucket{method=\"decode\", container=\"$service\"}[5m])) by (le)", + "format": "heatmap", + "interval": "", + "legendFormat": "{{ le }}", + "range": true, + "refId": "A" + } + ], + "title": "Forward Latency", + "tooltip": { + "show": true, + "showHistogram": false + }, + "type": "heatmap", + "xAxis": { + "show": true + }, + "yAxis": { + "decimals": 1, + "format": "s", + "logBase": 1, + "show": true + }, + "yBucketBound": "auto" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + 
}, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 11, + "w": 6, + "x": 12, + "y": 73 + }, + "id": 34, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_batch_decode_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_batch_decode_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_batch_decode_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Token Decode quantiles", + "type": "timeseries" + }, + { + "cards": {}, + "color": { + "cardColor": "#5794F2", + "colorScale": "linear", + "colorScheme": "interpolateSpectral", + "exponent": 0.5, + "min": 0, + "mode": "opacity" + }, + "dataFormat": "tsbuckets", + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, + "gridPos": { + "h": 11, + "w": 6, + "x": 18, + "y": 73 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 40, + "legend": { + "show": false + }, + "maxDataPoints": 25, + "options": { + "calculate": false, + "calculation": {}, + "cellGap": 2, + "cellValues": {}, + "color": { + "exponent": 0.5, + "fill": "#5794F2", + "min": 0, + "mode": "scheme", + "reverse": false, + "scale": "exponential", + "scheme": "Spectral", + "steps": 128 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": false + }, + "rowsFrame": { + "layout": "auto" + }, + "showValue": "never", + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": false + }, + "yAxis": { + "axisPlacement": "left", + "decimals": 1, + "reverse": false, + "unit": "s" + } + }, + "pluginVersion": "10.4.2", + "reverseYBuckets": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "exemplar": true, + "expr": 
"sum(increase(tgi_batch_decode_duration_bucket{method=\"decode\", container=\"$service\"}[5m])) by (le)", + "format": "heatmap", + "interval": "", + "legendFormat": "{{ le }}", + "range": true, + "refId": "A" + } + ], + "title": "Token Decode Latency", + "tooltip": { + "show": true, + "showHistogram": false + }, + "type": "heatmap", + "xAxis": { + "show": true + }, + "yAxis": { + "decimals": 1, + "format": "s", + "logBase": 1, + "show": true + }, + "yBucketBound": "auto" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 11, + "w": 6, + "x": 0, + "y": 84 + }, + "id": 42, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_batch_filter_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_batch_filter_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_batch_filter_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Filter Batch quantiles", + "type": "timeseries" + }, + { + "cards": {}, + "color": { + "cardColor": "#5794F2", + "colorScale": "linear", + "colorScheme": "interpolateSpectral", + "exponent": 0.5, + "min": 0, + "mode": "opacity" + }, + 
"dataFormat": "tsbuckets", + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, + "gridPos": { + "h": 11, + "w": 6, + "x": 6, + "y": 84 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 39, + "legend": { + "show": false + }, + "maxDataPoints": 25, + "options": { + "calculate": false, + "calculation": {}, + "cellGap": 2, + "cellValues": {}, + "color": { + "exponent": 0.5, + "fill": "#5794F2", + "min": 0, + "mode": "scheme", + "reverse": false, + "scale": "exponential", + "scheme": "Spectral", + "steps": 128 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": false + }, + "rowsFrame": { + "layout": "auto" + }, + "showValue": "never", + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": false + }, + "yAxis": { + "axisPlacement": "left", + "decimals": 1, + "reverse": false, + "unit": "s" + } + }, + "pluginVersion": "10.4.2", + "reverseYBuckets": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "exemplar": true, + "expr": "sum(increase(tgi_batch_filter_duration_bucket{method=\"decode\", container=\"$service\"}[5m])) by (le)", + "format": "heatmap", + "interval": "", + "legendFormat": "{{ le }}", + "range": true, + "refId": "A" + } + ], + "title": "Filter Batch Latency", + "tooltip": { + "show": true, + "showHistogram": false + }, + "type": "heatmap", + "xAxis": { + "show": true + }, + "yAxis": { + "decimals": 1, + "format": "s", + "logBase": 1, + "show": true + }, + "yBucketBound": "auto" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "p50" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p90" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "p99" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 11, + "w": 6, + "x": 12, + "y": 84 + }, + "id": 36, + "options": { + "legend": { + "calcs": [ + "min", + "max" + ], + "displayMode": 
"list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_batch_concat_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_batch_concat_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p90", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_batch_concat_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))", + "hide": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Batch Concat quantiles", + "type": "timeseries" + }, + { + "cards": {}, + "color": { + "cardColor": "#5794F2", + "colorScale": "linear", + "colorScheme": "interpolateSpectral", + "exponent": 0.5, + "min": 0, + "mode": "opacity" + }, + "dataFormat": "tsbuckets", + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, + "gridPos": { + "h": 11, + "w": 6, + "x": 18, + "y": 84 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 41, + "legend": { + "show": false + }, + "maxDataPoints": 25, + "options": { + "calculate": false, + "calculation": {}, + "cellGap": 2, + "cellValues": {}, + "color": { + "exponent": 0.5, + "fill": "#5794F2", + "min": 0, + "mode": "scheme", + "reverse": false, + "scale": "exponential", + "scheme": "Spectral", + "steps": 128 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": false + }, + "rowsFrame": { + "layout": "auto" + }, + "showValue": "never", + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": false + }, + "yAxis": { + "axisPlacement": "left", + "decimals": 1, + "reverse": false, + "unit": "s" + } + }, + "pluginVersion": "10.4.2", + "reverseYBuckets": false, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "editorMode": "code", + "exemplar": true, + "expr": "sum(increase(tgi_batch_concat_duration_bucket{method=\"decode\", container=\"$service\"}[5m])) by (le)", + "format": "heatmap", + "interval": "", + "legendFormat": "{{ le }}", + "range": true, + "refId": "A" + } + ], + "title": "Batch Concat latency", + "tooltip": { + "show": true, + "showHistogram": false + }, + "type": "heatmap", + "xAxis": { + "show": true + }, + "yAxis": { + "decimals": 1, + "format": "s", + "logBase": 1, + "show": true + }, + "yBucketBound": "auto" + } + ], + "refresh": "", + "schemaVersion": 39, + "tags": [], + "templating": { + "list": [ + { + "current": { + "selected": false, + "text": "gpu-txt-gen-cohereforai-c4ai-command-r-plu-ba7f1", + "value": "gpu-txt-gen-cohereforai-c4ai-command-r-plu-ba7f1" + }, + 
"datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}" + }, + "definition": "label_values(tgi_request_count, container)", + "hide": 0, + "includeAll": false, + "multi": false, + "name": "service", + "options": [], + "query": { + "query": "label_values(tgi_request_count, container)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 1, + "type": "query" + } + ] + }, + "time": { + "from": "now-30m", + "to": "now-30s" + }, + "timepicker": { + "nowDelay": "30s" + }, + "timezone": "", + "title": "Text Generation Inference", + "uid": "RHSk7EL4kdqsd", + "version": 12, + "weekStart": "" +} diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml new file mode 100644 index 00000000..756460e0 --- /dev/null +++ b/benchmark/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "text-generation-benchmark" +description = "Text Generation Benchmarking tool" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "text-generation-benchmark" +path = "src/main.rs" + +[dependencies] +average = "0.14" +clap = { version = "4.4.5", features = ["derive", "env"] } +crossterm = "0.27" +float-ord = "0.3.2" +serde = {version = "1.0.188", features = ["derive"]} +serde_json = "1.0" +tabled = "0.14.0" +text-generation-client = { path = "../router/client" } +thiserror = "1.0.48" +tokenizers = { workspace = true } +tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] } +tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]} +tracing = "0.1.37" +tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +hf-hub = { workspace = true } diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 00000000..17a02a30 --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,30 @@ +
+ +# Text Generation Inference benchmarking tool + +![benchmark](../assets/benchmark.png) + +
+ +A lightweight benchmarking tool inspired by [oha](https://github.com/hatoo/oha) +and powered by [tui](https://github.com/tui-rs-revival/ratatui). + +## Install + +```shell +make install-benchmark +``` + +## Run + +First, start `text-generation-inference`: + +```shell +text-generation-launcher --model-id bigscience/bloom-560m +``` + +Then run the benchmarking tool: + +```shell +text-generation-benchmark --tokenizer-name bigscience/bloom-560m +``` diff --git a/benchmark/src/app.rs b/benchmark/src/app.rs new file mode 100644 index 00000000..a0a9313a --- /dev/null +++ b/benchmark/src/app.rs @@ -0,0 +1,692 @@ +/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs +use crate::generation::{Decode, Message, Prefill}; +use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; +use text_generation_client::ClientError; +use tokio::sync::mpsc; +use tui::backend::Backend; +use tui::layout::{Alignment, Constraint, Direction, Layout}; +use tui::style::{Color, Modifier, Style}; +use tui::text::{Line, Span}; +use tui::widgets::{ + Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs, +}; +use tui::{symbols, Frame}; + +/// TUI powered App +pub(crate) struct App { + pub(crate) running: bool, + pub(crate) data: Data, + completed_runs: Vec<usize>, + completed_batch: usize, + current_batch: usize, + current_tab: usize, + touched_tab: bool, + zoom: bool, + is_error: bool, + tokenizer_name: String, + sequence_length: u32, + decode_length: u32, + n_run: usize, + receiver: mpsc::Receiver<Result<Message, ClientError>>, +} + +impl App { + pub(crate) fn new( + receiver: mpsc::Receiver<Result<Message, ClientError>>, + tokenizer_name: String, + sequence_length: u32, + decode_length: u32, + n_run: usize, + batch_size: Vec<u32>, + ) -> Self { + let current_tab = 0; + + let completed_runs: Vec<usize> = (0..batch_size.len()).map(|_| 0).collect(); + let completed_batch = 0; + let current_batch = 0; + let is_error = false; + + let data = Data::new(n_run, batch_size); + + Self { + running: true, + data, + completed_runs, + completed_batch, + current_batch, + current_tab, + touched_tab: false, + zoom: false, + is_error, + tokenizer_name, + sequence_length, + decode_length, + n_run, + receiver, + } + } + + /// Handle crossterm key events + pub(crate) fn handle_key_event(&mut self, key_event: KeyEvent) { + match key_event { + // Increase and wrap tab + KeyEvent { + code: KeyCode::Right, + .. + } + | KeyEvent { + code: KeyCode::Tab, .. + } => { + self.touched_tab = true; + self.current_tab = (self.current_tab + 1) % self.data.batch_size.len(); + } + // Decrease and wrap tab + KeyEvent { + code: KeyCode::Left, + .. + } => { + self.touched_tab = true; + if self.current_tab > 0 { + self.current_tab -= 1; + } else { + self.current_tab = self.data.batch_size.len() - 1; + } + } + // Zoom on throughput/latency fig + KeyEvent { + code: KeyCode::Char('+'), + .. + } => { + self.zoom = true; + } + // Unzoom on throughput/latency fig + KeyEvent { + code: KeyCode::Char('-'), + .. + } => { + self.zoom = false; + } + // Quit + KeyEvent { + code: KeyCode::Char('q'), + .. + } + | KeyEvent { + code: KeyCode::Char('c'), + modifiers: KeyModifiers::CONTROL, + ..
+ } => { + self.running = false; + } + _ => (), + } + } + + /// Get all pending messages from generation task + pub(crate) fn tick(&mut self) { + while let Ok(message) = self.receiver.try_recv() { + match message { + Ok(message) => match message { + Message::Prefill(step) => self.data.push_prefill(step, self.current_batch), + Message::Decode(step) => self.data.push_decode(step, self.current_batch), + Message::EndRun => { + self.completed_runs[self.current_batch] += 1; + } + Message::EndBatch => { + self.data.end_batch(self.current_batch); + self.completed_batch += 1; + + if self.current_batch < self.data.batch_size.len() - 1 { + // Only go to next tab if the user never touched the tab keys + if !self.touched_tab { + self.current_tab += 1; + } + + self.current_batch += 1; + } + } + Message::Warmup => {} + }, + Err(_) => self.is_error = true, + } + } + } + + /// Render frame + pub fn render(&mut self, f: &mut Frame<'_, B>) { + let batch_progress = + (self.completed_batch as f64 / self.data.batch_size.len() as f64).clamp(0.0, 1.0); + let run_progress = + (self.completed_runs[self.current_batch] as f64 / self.n_run as f64).clamp(0.0, 1.0); + + // Vertical layout + let row5 = Layout::default() + .direction(Direction::Vertical) + .constraints( + [ + Constraint::Length(1), + Constraint::Length(3), + Constraint::Length(3), + Constraint::Length(13), + Constraint::Min(10), + ] + .as_ref(), + ) + .split(f.size()); + + // Top row horizontal layout + let top = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref()) + .split(row5[2]); + + // Mid row horizontal layout + let mid = Layout::default() + .direction(Direction::Horizontal) + .constraints( + [ + Constraint::Percentage(25), + Constraint::Percentage(25), + Constraint::Percentage(25), + Constraint::Percentage(25), + ] + .as_ref(), + ) + .split(row5[3]); + + // Left mid row vertical layout + let prefill_text = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Length(8), Constraint::Length(5)].as_ref()) + .split(mid[0]); + + // Right mid row vertical layout + let decode_text = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Length(8), Constraint::Length(5)].as_ref()) + .split(mid[2]); + let decode_text_latency = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref()) + .split(decode_text[0]); + + // Bottom row horizontal layout + let bottom = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref()) + .split(row5[4]); + + // Title + let title = Block::default() + .borders(Borders::NONE) + .title(format!( + "Model: {} | Sequence Length: {} | Decode Length: {}", + self.tokenizer_name, self.sequence_length, self.decode_length + )) + .style( + Style::default() + .add_modifier(Modifier::BOLD) + .fg(Color::White), + ); + f.render_widget(title, row5[0]); + + // Helper + let helper = Block::default() + .borders(Borders::NONE) + .title("<- | tab | ->: change batch tab | q / CTRL + c: quit | +/-: zoom") + .title_alignment(Alignment::Right) + .style(Style::default().fg(Color::White)); + f.render_widget(helper, row5[0]); + + // Batch tabs + let titles = self + .data + .batch_size + .iter() + .map(|b| { + Line::from(vec![Span::styled( + format!("Batch: {b}"), + Style::default().fg(Color::White), + )]) + }) + .collect(); + let tabs = Tabs::new(titles) + 
.block(Block::default().borders(Borders::ALL).title("Tabs")) + .select(self.current_tab) + .style(Style::default().fg(Color::LightCyan)) + .highlight_style( + Style::default() + .add_modifier(Modifier::BOLD) + .bg(Color::Black), + ); + f.render_widget(tabs, row5[1]); + + // Total progress bar + let color = if self.is_error { + Color::Red + } else { + Color::LightGreen + }; + let batch_gauge = progress_gauge( + "Total Progress", + format!("{} / {}", self.completed_batch, self.data.batch_size.len()), + batch_progress, + color, + ); + f.render_widget(batch_gauge, top[0]); + + // Batch progress Bar + let color = if self.is_error { + Color::Red + } else { + Color::LightBlue + }; + let run_gauge = progress_gauge( + "Batch Progress", + format!( + "{} / {}", + self.completed_runs[self.current_batch], self.n_run + ), + run_progress, + color, + ); + f.render_widget(run_gauge, top[1]); + + // Prefill text infos + let prefill_latency_block = latency_paragraph( + &mut self.data.prefill_latencies[self.current_tab], + "Prefill", + ); + let prefill_throughput_block = + throughput_paragraph(&self.data.prefill_throughputs[self.current_tab], "Prefill"); + + f.render_widget(prefill_latency_block, prefill_text[0]); + f.render_widget(prefill_throughput_block, prefill_text[1]); + + // Prefill latency histogram + let histo_width = 7; + let bins = if mid[1].width < 2 { + 0 + } else { + (mid[1].width as usize - 2) / (histo_width + 1) + } + .max(2); + + let histo_data = + latency_histogram_data(&self.data.prefill_latencies[self.current_tab], bins); + let histo_data_str: Vec<(&str, u64)> = + histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect(); + let prefill_histogram = + latency_histogram(&histo_data_str, "Prefill").bar_width(histo_width as u16); + f.render_widget(prefill_histogram, mid[1]); + + // Decode text info + let decode_latency_block = latency_paragraph( + &mut self.data.decode_latencies[self.current_tab], + "Decode Total", + ); + let decode_token_latency_block = latency_paragraph( + &mut self.data.decode_token_latencies[self.current_tab], + "Decode Token", + ); + let decode_throughput_block = + throughput_paragraph(&self.data.decode_throughputs[self.current_tab], "Decode"); + f.render_widget(decode_latency_block, decode_text_latency[0]); + f.render_widget(decode_token_latency_block, decode_text_latency[1]); + f.render_widget(decode_throughput_block, decode_text[1]); + + // Decode latency histogram + let histo_data = + latency_histogram_data(&self.data.decode_latencies[self.current_tab], bins); + let histo_data_str: Vec<(&str, u64)> = + histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect(); + let decode_histogram = + latency_histogram(&histo_data_str, "Decode").bar_width(histo_width as u16); + f.render_widget(decode_histogram, mid[3]); + + // Prefill latency/throughput chart + let prefill_latency_throughput_chart = latency_throughput_chart( + &self.data.prefill_batch_latency_throughput, + &self.data.batch_size, + self.zoom, + "Prefill", + ); + f.render_widget(prefill_latency_throughput_chart, bottom[0]); + + // Decode latency/throughput chart + let decode_latency_throughput_chart = latency_throughput_chart( + &self.data.decode_batch_latency_throughput, + &self.data.batch_size, + self.zoom, + "Decode", + ); + f.render_widget(decode_latency_throughput_chart, bottom[1]); + } +} + +/// App internal data struct +pub(crate) struct Data { + pub(crate) batch_size: Vec, + pub(crate) prefill_latencies: Vec>, + pub(crate) prefill_throughputs: Vec>, + pub(crate) decode_latencies: Vec>, + pub(crate) 
decode_token_latencies: Vec<Vec<f64>>, + pub(crate) decode_throughputs: Vec<Vec<f64>>, + pub(crate) prefill_batch_latency_throughput: Vec<(f64, f64)>, + pub(crate) decode_batch_latency_throughput: Vec<(f64, f64)>, +} + +impl Data { + fn new(n_run: usize, batch_size: Vec<u32>) -> Self { + let prefill_latencies: Vec<Vec<f64>> = (0..batch_size.len()) + .map(|_| Vec::with_capacity(n_run)) + .collect(); + let prefill_throughputs: Vec<Vec<f64>> = prefill_latencies.clone(); + + let decode_latencies: Vec<Vec<f64>> = prefill_latencies.clone(); + let decode_token_latencies: Vec<Vec<f64>> = decode_latencies.clone(); + let decode_throughputs: Vec<Vec<f64>> = prefill_throughputs.clone(); + + let prefill_batch_latency_throughput: Vec<(f64, f64)> = + Vec::with_capacity(batch_size.len()); + let decode_batch_latency_throughput: Vec<(f64, f64)> = + prefill_batch_latency_throughput.clone(); + + Self { + batch_size, + prefill_latencies, + prefill_throughputs, + decode_latencies, + decode_token_latencies, + decode_throughputs, + prefill_batch_latency_throughput, + decode_batch_latency_throughput, + } + } + + fn push_prefill(&mut self, prefill: Prefill, batch_idx: usize) { + let latency = prefill.latency.as_micros() as f64 / 1000.0; + self.prefill_latencies[batch_idx].push(latency); + self.prefill_throughputs[batch_idx].push(prefill.throughput); + } + + fn push_decode(&mut self, decode: Decode, batch_idx: usize) { + let latency = decode.latency.as_micros() as f64 / 1000.0; + let token_latency = decode.token_latency.as_micros() as f64 / 1000.0; + self.decode_latencies[batch_idx].push(latency); + self.decode_token_latencies[batch_idx].push(token_latency); + self.decode_throughputs[batch_idx].push(decode.throughput); + } + + fn end_batch(&mut self, batch_idx: usize) { + self.prefill_batch_latency_throughput.push(( + self.prefill_latencies[batch_idx].iter().sum::<f64>() + / self.prefill_latencies[batch_idx].len() as f64, + self.prefill_throughputs[batch_idx].iter().sum::<f64>() + / self.prefill_throughputs[batch_idx].len() as f64, + )); + self.decode_batch_latency_throughput.push(( + self.decode_latencies[batch_idx].iter().sum::<f64>() + / self.decode_latencies[batch_idx].len() as f64, + self.decode_throughputs[batch_idx].iter().sum::<f64>() + / self.decode_throughputs[batch_idx].len() as f64, + )); + } +} + +/// Progress bar +fn progress_gauge(title: &str, label: String, progress: f64, color: Color) -> Gauge { + Gauge::default() + .block(Block::default().title(title).borders(Borders::ALL)) + .gauge_style(Style::default().fg(color)) + .label(Span::raw(label)) + .ratio(progress) +} + +/// Throughput paragraph +fn throughput_paragraph<'a>(throughput: &[f64], name: &'static str) -> Paragraph<'a> { + // Throughput average/high/low texts + let throughput_texts = statis_spans(throughput, "tokens/secs"); + + // Throughput block + Paragraph::new(throughput_texts).block( + Block::default() + .title(Span::raw(format!("{name} Throughput"))) + .borders(Borders::ALL), + ) +} + +/// Latency paragraph +fn latency_paragraph<'a>(latency: &mut [f64], name: &'static str) -> Paragraph<'a> { + // Latency average/high/low texts + let mut latency_texts = statis_spans(latency, "ms"); + + // Sort latency for percentiles + float_ord::sort(latency); + let latency_percentiles = crate::utils::percentiles(latency, &[50, 90, 99]); + + // Latency p50/p90/p99 texts + let colors = [Color::LightGreen, Color::LightYellow, Color::LightRed]; + for (i, (name, value)) in latency_percentiles.iter().enumerate() { + let span = Line::from(vec![Span::styled( + format!("{name}: {value:.2} ms"), + Style::default().fg(colors[i]), + )]); +
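// Append the styled percentile line (p50/p90/p99) below the average/low/high stats +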
latency_texts.push(span); + } + + Paragraph::new(latency_texts).block( + Block::default() + .title(Span::raw(format!("{name} Latency"))) + .borders(Borders::ALL), + ) +} + +/// Average/High/Low spans +fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec> { + vec![ + Line::from(vec![Span::styled( + format!( + "Average: {:.2} {unit}", + data.iter().sum::() / data.len() as f64 + ), + Style::default().fg(Color::LightBlue), + )]), + Line::from(vec![Span::styled( + format!( + "Lowest: {:.2} {unit}", + data.iter() + .min_by(|a, b| a.total_cmp(b)) + .unwrap_or(&f64::NAN) + ), + Style::default().fg(Color::Reset), + )]), + Line::from(vec![Span::styled( + format!( + "Highest: {:.2} {unit}", + data.iter() + .max_by(|a, b| a.total_cmp(b)) + .unwrap_or(&f64::NAN) + ), + Style::default().fg(Color::Reset), + )]), + ] +} + +/// Latency histogram data +fn latency_histogram_data(latency: &[f64], bins: usize) -> Vec<(String, u64)> { + let histo_data: Vec<(String, u64)> = { + let histo = crate::utils::histogram(latency, bins); + histo + .into_iter() + .map(|(label, v)| (format!("{label:.2}"), v as u64)) + .collect() + }; + + histo_data +} + +/// Latency Histogram +fn latency_histogram<'a>( + histo_data_str: &'a Vec<(&'a str, u64)>, + name: &'static str, +) -> BarChart<'a> { + BarChart::default() + .block( + Block::default() + .title(format!("{name} latency histogram")) + .style(Style::default().fg(Color::LightYellow).bg(Color::Reset)) + .borders(Borders::ALL), + ) + .data(histo_data_str.as_slice()) +} + +/// Latency/Throughput chart +fn latency_throughput_chart<'a>( + latency_throughput: &'a [(f64, f64)], + batch_sizes: &'a [u32], + zoom: bool, + name: &'static str, +) -> Chart<'a> { + let latency_iter = latency_throughput.iter().map(|(l, _)| l); + let throughput_iter = latency_throughput.iter().map(|(_, t)| t); + + // Get extreme values + let min_latency: f64 = *latency_iter + .clone() + .min_by(|a, b| a.total_cmp(b)) + .unwrap_or(&f64::NAN); + let max_latency: f64 = *latency_iter + .max_by(|a, b| a.total_cmp(b)) + .unwrap_or(&f64::NAN); + let min_throughput: f64 = *throughput_iter + .clone() + .min_by(|a, b| a.total_cmp(b)) + .unwrap_or(&f64::NAN); + let max_throughput: f64 = *throughput_iter + .max_by(|a, b| a.total_cmp(b)) + .unwrap_or(&f64::NAN); + + // Char min max values + let min_x = if zoom { + ((min_latency - 0.05 * min_latency) / 100.0).floor() * 100.0 + } else { + 0.0 + }; + let max_x = ((max_latency + 0.05 * max_latency) / 100.0).ceil() * 100.0; + let step_x = (max_x - min_x) / 4.0; + + // Chart min max values + let min_y = if zoom { + ((min_throughput - 0.05 * min_throughput) / 100.0).floor() * 100.0 + } else { + 0.0 + }; + let max_y = ((max_throughput + 0.05 * max_throughput) / 100.0).ceil() * 100.0; + let step_y = (max_y - min_y) / 4.0; + + // Labels + let mut x_labels = vec![Span::styled( + format!("{min_x:.2}"), + Style::default() + .add_modifier(Modifier::BOLD) + .fg(Color::Gray) + .bg(Color::Reset), + )]; + for i in 0..3 { + x_labels.push(Span::styled( + format!("{:.2}", min_x + ((i + 1) as f64 * step_x)), + Style::default().fg(Color::Gray).bg(Color::Reset), + )); + } + x_labels.push(Span::styled( + format!("{max_x:.2}"), + Style::default() + .add_modifier(Modifier::BOLD) + .fg(Color::Gray) + .bg(Color::Reset), + )); + + // Labels + let mut y_labels = vec![Span::styled( + format!("{min_y:.2}"), + Style::default() + .add_modifier(Modifier::BOLD) + .fg(Color::Gray) + .bg(Color::Reset), + )]; + for i in 0..3 { + y_labels.push(Span::styled( + format!("{:.2}", min_y + ((i + 1) as f64 * 
step_y)), + Style::default().fg(Color::Gray).bg(Color::Reset), + )); + } + y_labels.push(Span::styled( + format!("{max_y:.2}"), + Style::default() + .add_modifier(Modifier::BOLD) + .fg(Color::Gray) + .bg(Color::Reset), + )); + + // Chart dataset + let colors = color_vec(); + let datasets: Vec = (0..latency_throughput.len()) + .map(|i| { + let color_idx = i % colors.len(); + + Dataset::default() + .name(batch_sizes[i].to_string()) + .marker(symbols::Marker::Block) + .style(Style::default().fg(colors[color_idx])) + .graph_type(GraphType::Scatter) + .data(&latency_throughput[i..(i + 1)]) + }) + .collect(); + + // Chart + Chart::new(datasets) + .style(Style::default().fg(Color::Cyan).bg(Color::Reset)) + .block( + Block::default() + .title(Span::styled( + format!("{name} throughput over latency"), + Style::default().fg(Color::Gray).bg(Color::Reset), + )) + .borders(Borders::ALL), + ) + .x_axis( + Axis::default() + .title("ms") + .style(Style::default().fg(Color::Gray).bg(Color::Reset)) + .labels(x_labels) + .bounds([min_x, max_x]), + ) + .y_axis( + Axis::default() + .title("tokens/secs") + .style(Style::default().fg(Color::Gray).bg(Color::Reset)) + .labels(y_labels) + .bounds([min_y, max_y]), + ) +} + +// Colors for latency/throughput chart +fn color_vec() -> Vec { + vec![ + Color::Red, + Color::Green, + Color::Yellow, + Color::Blue, + Color::Magenta, + Color::Cyan, + Color::Gray, + Color::DarkGray, + Color::LightRed, + Color::LightGreen, + Color::LightYellow, + Color::LightBlue, + Color::LightMagenta, + Color::LightCyan, + ] +} diff --git a/benchmark/src/event.rs b/benchmark/src/event.rs new file mode 100644 index 00000000..07482aed --- /dev/null +++ b/benchmark/src/event.rs @@ -0,0 +1,65 @@ +/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs +use crossterm::event; +use std::time::{Duration, Instant}; +use tokio::sync::{broadcast, mpsc}; + +/// Events +#[derive(Debug)] +pub(crate) enum Event { + /// Terminal tick. + Tick, + /// Key press. + Key(event::KeyEvent), + /// Terminal resize. + Resize, +} + +pub(crate) async fn terminal_event_task( + fps: u32, + event_sender: mpsc::Sender, + mut shutdown_receiver: broadcast::Receiver<()>, + _shutdown_guard_sender: mpsc::Sender<()>, +) { + // End task if a message is received on shutdown_receiver + // _shutdown_guard_sender will be dropped once the task is finished + tokio::select! 
{ + _ = event_loop(fps, event_sender) => { + }, + _ = shutdown_receiver.recv() => {} + } +} + +/// Main event loop +async fn event_loop(fps: u32, event_sender: mpsc::Sender<Event>) { + // Frame budget + let per_frame = Duration::from_secs(1) / fps; + + // When was last frame executed + let mut last_frame = Instant::now(); + + loop { + // Sleep to avoid blocking the thread for too long + if let Some(sleep) = per_frame.checked_sub(last_frame.elapsed()) { + tokio::time::sleep(sleep).await; + } + + // Get crossterm event and send a new one over the channel + if event::poll(Duration::from_secs(0)).expect("no events available") { + match event::read().expect("unable to read event") { + event::Event::Key(e) => event_sender.send(Event::Key(e)).await.unwrap_or(()), + event::Event::Resize(_w, _h) => { + event_sender.send(Event::Resize).await.unwrap_or(()) + } + _ => (), + } + } + + // Frame budget exceeded + if last_frame.elapsed() >= per_frame { + // Send tick + event_sender.send(Event::Tick).await.unwrap_or(()); + // Reset last_frame time + last_frame = Instant::now(); + } + } +} diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs new file mode 100644 index 00000000..5135f02f --- /dev/null +++ b/benchmark/src/generation.rs @@ -0,0 +1,228 @@ +use std::time::{Duration, Instant}; +use text_generation_client::{ + Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient, + StoppingCriteriaParameters, +}; +use tokenizers::{Tokenizer, TruncationDirection}; +use tokio::sync::{broadcast, mpsc}; + +const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."; + +#[derive(Debug, Clone)] +pub(crate) struct Prefill { + pub(crate) latency: Duration, + pub(crate) throughput: f64, +} + +#[derive(Debug, Clone)] +pub(crate) struct Decode { + pub(crate) latency: Duration, + pub(crate) token_latency: Duration, + pub(crate) throughput: f64, +} + +#[derive(Debug)] +pub(crate) enum Message { + Warmup, + Prefill(Prefill), + Decode(Decode), + EndRun, + EndBatch, +} + +/// Benchmarking task +#[allow(clippy::too_many_arguments)] +pub(crate) async fn generation_task( + tokenizer: Tokenizer, + batch_size: Vec<u32>, + sequence_length: u32, + decode_length: u32, + top_n_tokens: Option<u32>, + n_runs: usize, + warmups: usize, + parameters: NextTokenChooserParameters, + client: ShardedClient, + run_sender: mpsc::Sender<Result<Message, ClientError>>, + mut shutdown_receiver: broadcast::Receiver<()>, + _shutdown_guard_sender: mpsc::Sender<()>, +) { + // End task if a message is received on shutdown_receiver + // _shutdown_guard_sender will be dropped once the task is finished + tokio::select!
{ + res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => { + if let Err(err) = res { + run_sender.send(Err(err)).await.unwrap_or(()); + } + }, + _ = shutdown_receiver.recv() => {} + } +} + +/// Benchmark prefill/decode +#[allow(clippy::too_many_arguments)] +async fn generate_runs( + tokenizer: Tokenizer, + batch_size: Vec, + sequence_length: u32, + decode_length: u32, + top_n_tokens: Option, + n_runs: usize, + warmups: usize, + parameters: NextTokenChooserParameters, + mut client: ShardedClient, + run_sender: mpsc::Sender>, +) -> Result<(), ClientError> { + // Create a dummy sequence + let sequence = create_sequence(sequence_length, tokenizer); + + for b in batch_size { + // Warmups on batch size + for _ in 0..warmups { + let (_, decode_batch) = prefill( + sequence.clone(), + sequence_length, + b, + decode_length, + parameters.clone(), + top_n_tokens, + &mut client, + ) + .await?; + let _ = decode(decode_batch, &mut client).await?; + // Send warmup message + run_sender.send(Ok(Message::Warmup)).await.unwrap_or(()); + } + + for _ in 0..n_runs { + let (prefill, decode_batch) = prefill( + sequence.clone(), + sequence_length, + b, + decode_length, + parameters.clone(), + top_n_tokens, + &mut client, + ) + .await?; + // Send prefill message + run_sender + .send(Ok(Message::Prefill(prefill))) + .await + .unwrap_or(()); + + let decode = decode(decode_batch, &mut client).await?; + + // Send decode message + run_sender + .send(Ok(Message::Decode(decode))) + .await + .unwrap_or(()); + + // Send run ended message + run_sender.send(Ok(Message::EndRun)).await.unwrap_or(()); + } + // Batch ended + run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(()); + } + Ok(()) +} + +// Run a prefill step +async fn prefill( + sequence: String, + sequence_length: u32, + batch_size: u32, + decode_length: u32, + parameters: NextTokenChooserParameters, + top_n_tokens: Option, + client: &mut ShardedClient, +) -> Result<(Prefill, CachedBatch), ClientError> { + // Create requests + let requests = (0..batch_size) + .map(|id| Request { + id: id.into(), + prefill_logprobs: false, + inputs: sequence.clone(), + truncate: sequence_length, + parameters: Some(parameters.clone()), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: decode_length, + stop_sequences: vec![], + ignore_eos_token: true, // Will not stop even if a eos token is generated + }), + top_n_tokens: top_n_tokens.unwrap_or(0), + lora_id: None, + }) + .collect(); + + let batch = Batch { + id: 0, + requests, + size: batch_size, + max_tokens: batch_size * (sequence_length + decode_length), + }; + + // Run prefill + let start_time = Instant::now(); + let (_, decode_batch, _) = client.prefill(batch.clone()).await?; + + // Get latency + let latency = start_time.elapsed(); + + // Compute throughput from latency and batch size + let throughput = batch_size as f64 / latency.as_secs_f64(); + + // Decode batch cannot be empty + let decode_batch = decode_batch.expect("decode_batch is None. 
This is a bug."); + + let step = Prefill { + latency, + throughput, + }; + + Ok((step, decode_batch)) +} + +/// Run a full decode +async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result { + let mut decode_length = 0; + let batch_size = batch.size; + + let start_time = Instant::now(); + + // Full decode over decode length + let mut next_batch = Some(batch); + while let Some(batch) = next_batch { + let result = client.decode(vec![batch]).await?; + next_batch = result.1; + decode_length += 1; + } + + // Get latency + let latency = start_time.elapsed(); + let token_latency = latency / decode_length; + + // Compute throughput from latency, batch size and decode length + let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64(); + + let step = Decode { + latency, + token_latency, + throughput, + }; + Ok(step) +} + +/// Create a dummy sequence of the correct length +fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String { + let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len(); + // Repeat lorem ipsum to cover sequence length + let string_sequence = + LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len()); + // Encode sequence + let mut encoding = tokenizer.encode(string_sequence, true).unwrap(); + // Truncate to sequence_length + encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left); + // Decode + tokenizer.decode(encoding.get_ids(), false).unwrap() +} diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs new file mode 100644 index 00000000..638c6514 --- /dev/null +++ b/benchmark/src/lib.rs @@ -0,0 +1,160 @@ +mod app; +mod event; +mod generation; +mod table; +mod utils; + +use crate::app::App; +use crate::event::Event; +use crossterm::ExecutableCommand; +use std::io; +use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient}; +use tokenizers::Tokenizer; +use tokio::sync::{broadcast, mpsc}; +use tui::backend::CrosstermBackend; +use tui::Terminal; + +/// Run benchmarking app +#[allow(clippy::too_many_arguments)] +pub async fn run( + tokenizer_name: String, + tokenizer: Tokenizer, + batch_size: Vec, + sequence_length: u32, + decode_length: u32, + top_n_tokens: Option, + n_runs: usize, + warmups: usize, + temperature: Option, + top_k: Option, + top_p: Option, + typical_p: Option, + repetition_penalty: Option, + frequency_penalty: Option, + watermark: bool, + do_sample: bool, + client: ShardedClient, +) -> Result<(), std::io::Error> { + let parameters = NextTokenChooserParameters { + temperature: temperature.unwrap_or(1.0), + top_k: top_k.unwrap_or(0), + top_p: top_p.unwrap_or(1.0), + typical_p: typical_p.unwrap_or(1.0), + do_sample, + seed: 0, + repetition_penalty: repetition_penalty.unwrap_or(1.0), + frequency_penalty: frequency_penalty.unwrap_or(0.0), + watermark, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }; + + // Initialize terminal properties + crossterm::terminal::enable_raw_mode()?; + io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?; + io::stdout().execute(crossterm::cursor::Hide)?; + + // Initialize terminal + let mut terminal = { + let backend = CrosstermBackend::new(io::stdout()); + Terminal::new(backend)? 
+ }; + + // Create message channel between generation_task and app + let (run_sender, run_receiver) = mpsc::channel(8); + // Crossterm event channel + let (event_sender, mut event_receiver) = mpsc::channel(8); + // Shutdown channel to terminate tasks + let (shutdown_sender, _) = broadcast::channel(1); + // Channel to check if tasks terminated + let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1); + + // Create generation task + tokio::spawn(generation::generation_task( + tokenizer, + batch_size.clone(), + sequence_length, + decode_length, + top_n_tokens, + n_runs, + warmups, + parameters, + client, + run_sender, + shutdown_sender.subscribe(), + shutdown_guard_sender.clone(), + )); + + // Create event task + tokio::spawn(event::terminal_event_task( + 250, + event_sender, + shutdown_sender.subscribe(), + shutdown_guard_sender.clone(), + )); + + // Drop our end of shutdown sender + drop(shutdown_guard_sender); + + // Create App + let mut app = App::new( + run_receiver, + tokenizer_name.clone(), + sequence_length, + decode_length, + n_runs, + batch_size, + ); + + while app.running { + // Draw frame + terminal.draw(|frame| app.render(frame))?; + + // Await a new event from event handling task + match event_receiver.recv().await { + None => break, + // Update app state + Some(event) => match event { + Event::Tick => app.tick(), + Event::Key(key_event) => app.handle_key_event(key_event), + _ => {} + }, + } + } + + // Ask tasks to shutdown + let _ = shutdown_sender.send(()); + // Wait for tasks to shutdown + let _ = shutdown_guard_receiver.recv().await; + + // Revert terminal to original view + io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?; + crossterm::terminal::disable_raw_mode()?; + io::stdout().execute(crossterm::cursor::Show)?; + + let parameters_table = table::parameters_table( + tokenizer_name, + sequence_length, + decode_length, + top_n_tokens, + n_runs, + warmups, + temperature, + top_k, + top_p, + typical_p, + repetition_penalty, + frequency_penalty, + watermark, + do_sample, + ); + println!("\n{parameters_table}\n"); + + let latency_table = table::latency_table(&app.data); + println!("\n{latency_table}\n"); + + let throughput_table = table::throughput_table(&app.data); + println!("\n{throughput_table}\n"); + + Ok(()) +} diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs new file mode 100644 index 00000000..2d89e045 --- /dev/null +++ b/benchmark/src/main.rs @@ -0,0 +1,222 @@ +/// Text Generation Inference benchmarking tool +/// +/// Inspired by the great Oha app: https://github.com/hatoo/oha +/// and: https://github.com/orhun/rust-tui-template +use clap::Parser; +use std::path::Path; +use text_generation_client::ShardedClient; +use tokenizers::{FromPretrainedParameters, Tokenizer}; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::EnvFilter; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + /// The name of the tokenizer (as in model_id on the huggingface hub, or local path). + #[clap(short, long, env)] + tokenizer_name: String, + + /// The revision to use for the tokenizer if on the hub. 
+ #[clap(default_value = "main", long, env)] + revision: String, + + /// The various batch sizes to benchmark for, the idea is to get enough + /// batching to start seeing increased latency, this usually means you're + /// moving from memory bound (usual as BS=1) to compute bound, and this is + /// a sweet spot for the maximum batch size for the model under test + #[clap(short, long)] + batch_size: Option>, + + /// This is the initial prompt sent to the text-generation-server length + /// in token. Longer prompt will slow down the benchmark. Usually the + /// latency grows somewhat linearly with this for the prefill step. + /// + /// Most importantly, the prefill step is usually not the one dominating + /// your runtime, so it's ok to keep it short. + #[clap(default_value = "10", short, long, env)] + sequence_length: u32, + + /// This is how many tokens will be generated by the server and averaged out + /// to give the `decode` latency. This is the *critical* number you want to optimize for + /// LLM spend most of their time doing decoding. + /// + /// Decode latency is usually quite stable. + #[clap(default_value = "8", short, long, env)] + decode_length: u32, + + ///How many runs should we average from + #[clap(default_value = "10", short, long, env)] + runs: usize, + + /// Number of warmup cycles + #[clap(default_value = "1", short, long, env)] + warmups: usize, + + /// The location of the grpc socket. This benchmark tool bypasses the router + /// completely and directly talks to the gRPC processes + #[clap(default_value = "/tmp/text-generation-server-0", short, long, env)] + master_shard_uds_path: String, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + temperature: Option, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + top_k: Option, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + top_p: Option, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + typical_p: Option, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + repetition_penalty: Option, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + frequency_penalty: Option, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + watermark: bool, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + do_sample: bool, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + top_n_tokens: Option, +} + +fn main() -> Result<(), Box> { + init_logging(); + + // Get args 
+ let args = Args::parse(); + // Pattern match configuration + let Args { + tokenizer_name, + revision, + batch_size, + sequence_length, + decode_length, + runs, + warmups, + temperature, + top_k, + top_p, + typical_p, + repetition_penalty, + frequency_penalty, + watermark, + do_sample, + master_shard_uds_path, + top_n_tokens, + } = args; + + let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]); + + // Tokenizer instance + // This will only be used to validate payloads + tracing::info!("Loading tokenizer"); + let local_path = Path::new(&tokenizer_name); + let tokenizer = + if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists() + { + // Load local tokenizer + tracing::info!("Found local tokenizer"); + Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap() + } else { + tracing::info!("Downloading tokenizer"); + + // Parse Huggingface hub token + let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok(); + + // Download and instantiate tokenizer + // We need to download it outside of the Tokio runtime + let params = FromPretrainedParameters { + revision, + auth_token, + ..Default::default() + }; + Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap() + }; + tracing::info!("Tokenizer loaded"); + + // Launch Tokio runtime + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + // Instantiate sharded client from the master unix socket + tracing::info!("Connect to model server"); + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .expect("Could not connect to server"); + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .expect("Unable to clear cache"); + tracing::info!("Connected"); + + // Run app + text_generation_benchmark::run( + tokenizer_name, + tokenizer, + batch_size, + sequence_length, + decode_length, + top_n_tokens, + runs, + warmups, + temperature, + top_k, + top_p, + typical_p, + repetition_penalty, + frequency_penalty, + watermark, + do_sample, + sharded_client, + ) + .await + .unwrap(); + }); + Ok(()) +} + +/// Init logging using LOG_LEVEL +fn init_logging() { + // STDOUT/STDERR layer + let fmt_layer = tracing_subscriber::fmt::layer() + .with_file(true) + .with_line_number(true); + + // Filter events with LOG_LEVEL + let env_filter = + EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); + + tracing_subscriber::registry() + .with(env_filter) + .with(fmt_layer) + .init(); +} diff --git a/benchmark/src/table.rs b/benchmark/src/table.rs new file mode 100644 index 00000000..1585a25f --- /dev/null +++ b/benchmark/src/table.rs @@ -0,0 +1,174 @@ +use crate::app::Data; +use tabled::settings::Merge; +use tabled::{builder::Builder, settings::Style, Table}; + +#[allow(clippy::too_many_arguments)] +pub(crate) fn parameters_table( + tokenizer_name: String, + sequence_length: u32, + decode_length: u32, + top_n_tokens: Option, + n_runs: usize, + warmups: usize, + temperature: Option, + top_k: Option, + top_p: Option, + typical_p: Option, + repetition_penalty: Option, + frequency_penalty: Option, + watermark: bool, + do_sample: bool, +) -> Table { + let mut builder = Builder::default(); + + builder.set_header(["Parameter", "Value"]); + + builder.push_record(["Model", &tokenizer_name]); + builder.push_record(["Sequence Length", &sequence_length.to_string()]); + builder.push_record(["Decode Length", &decode_length.to_string()]); + 
builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]); + builder.push_record(["N Runs", &n_runs.to_string()]); + builder.push_record(["Warmups", &warmups.to_string()]); + builder.push_record(["Temperature", &format!("{temperature:?}")]); + builder.push_record(["Top K", &format!("{top_k:?}")]); + builder.push_record(["Top P", &format!("{top_p:?}")]); + builder.push_record(["Typical P", &format!("{typical_p:?}")]); + builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]); + builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]); + builder.push_record(["Watermark", &watermark.to_string()]); + builder.push_record(["Do Sample", &do_sample.to_string()]); + + let mut table = builder.build(); + table.with(Style::markdown()); + table +} + +pub(crate) fn latency_table(data: &Data) -> Table { + let mut builder = Builder::default(); + + builder.set_header([ + "Step", + "Batch Size", + "Average", + "Lowest", + "Highest", + "p50", + "p90", + "p99", + ]); + + add_latencies( + &mut builder, + "Prefill", + &data.batch_size, + &data.prefill_latencies, + ); + add_latencies( + &mut builder, + "Decode (token)", + &data.batch_size, + &data.decode_token_latencies, + ); + add_latencies( + &mut builder, + "Decode (total)", + &data.batch_size, + &data.decode_latencies, + ); + + let mut table = builder.build(); + table.with(Style::markdown()).with(Merge::vertical()); + table +} + +pub(crate) fn throughput_table(data: &Data) -> Table { + let mut builder = Builder::default(); + + builder.set_header(["Step", "Batch Size", "Average", "Lowest", "Highest"]); + + add_throuhgputs( + &mut builder, + "Prefill", + &data.batch_size, + &data.prefill_throughputs, + ); + add_throuhgputs( + &mut builder, + "Decode", + &data.batch_size, + &data.decode_throughputs, + ); + + let mut table = builder.build(); + table.with(Style::markdown()).with(Merge::vertical()); + table +} + +fn add_latencies( + builder: &mut Builder, + step: &'static str, + batch_size: &[u32], + batch_latencies: &[Vec], +) { + for (i, b) in batch_size.iter().enumerate() { + let latencies = &batch_latencies[i]; + let (avg, min, max) = avg_min_max(latencies); + + let row = [ + step, + &b.to_string(), + &format_value(avg, "ms"), + &format_value(min, "ms"), + &format_value(max, "ms"), + &format_value(px(latencies, 50), "ms"), + &format_value(px(latencies, 90), "ms"), + &format_value(px(latencies, 99), "ms"), + ]; + + builder.push_record(row); + } +} + +fn add_throuhgputs( + builder: &mut Builder, + step: &'static str, + batch_size: &[u32], + batch_throughputs: &[Vec], +) { + for (i, b) in batch_size.iter().enumerate() { + let throughputs = &batch_throughputs[i]; + let (avg, min, max) = avg_min_max(throughputs); + + let row = [ + step, + &b.to_string(), + &format_value(avg, "tokens/secs"), + &format_value(min, "tokens/secs"), + &format_value(max, "tokens/secs"), + ]; + + builder.push_record(row); + } +} + +fn avg_min_max(data: &[f64]) -> (f64, f64, f64) { + let average = data.iter().sum::() / data.len() as f64; + let min = data + .iter() + .min_by(|a, b| a.total_cmp(b)) + .unwrap_or(&f64::NAN); + let max = data + .iter() + .max_by(|a, b| a.total_cmp(b)) + .unwrap_or(&f64::NAN); + (average, *min, *max) +} + +fn px(data: &[f64], p: u32) -> f64 { + let i = (f64::from(p) / 100.0 * data.len() as f64) as usize; + *data.get(i).unwrap_or(&f64::NAN) +} + +fn format_value(value: f64, unit: &'static str) -> String { + format!("{:.2} {unit}", value) +} diff --git a/benchmark/src/utils.rs b/benchmark/src/utils.rs new file 
mode 100644 index 00000000..20469991 --- /dev/null +++ b/benchmark/src/utils.rs @@ -0,0 +1,43 @@ +/// MIT License +// +// Copyright (c) 2020 hatoo +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +use std::collections::BTreeMap; + +pub(crate) fn histogram(values: &[f64], bins: usize) -> Vec<(f64, usize)> { + assert!(bins >= 2); + let mut bucket: Vec = vec![0; bins]; + let min = values.iter().collect::().min(); + let max = values.iter().collect::().max(); + let step = (max - min) / (bins - 1) as f64; + + for &v in values { + let i = std::cmp::min(((v - min) / step).ceil() as usize, bins - 1); + bucket[i] += 1; + } + + bucket + .into_iter() + .enumerate() + .map(|(i, v)| (min + step * i as f64, v)) + .collect() +} + +pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap { + pecents + .iter() + .map(|&p| { + let i = (f64::from(p) / 100.0 * values.len() as f64) as usize; + (format!("p{p}"), *values.get(i).unwrap_or(&f64::NAN)) + }) + .collect() +} diff --git a/clients/python/.gitignore b/clients/python/.gitignore new file mode 100644 index 00000000..5a8ecaa7 --- /dev/null +++ b/clients/python/.gitignore @@ -0,0 +1,158 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +text_generation/__pycache__/ +text_generation/pb/__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
+# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +transformers +safetensors diff --git a/clients/python/Makefile b/clients/python/Makefile new file mode 100644 index 00000000..42720875 --- /dev/null +++ b/clients/python/Makefile @@ -0,0 +1,6 @@ +unit-tests: + python -m pytest --cov=text_generation tests + +install: + pip install pip --upgrade + pip install -e . diff --git a/clients/python/README.md b/clients/python/README.md new file mode 100644 index 00000000..bf37508e --- /dev/null +++ b/clients/python/README.md @@ -0,0 +1,279 @@ +# Text Generation + +The Hugging Face Text Generation Python library provides a convenient way of interfacing with a +`text-generation-inference` instance running on +[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub. 
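+ +At a glance, a typical non-streaming call looks like the following sketch; installation and the full client reference follow in the next sections. The endpoint URL is a placeholder, and the keyword arguments are assumed to map onto the `Parameters` fields documented in the Types section at the end of this README: + +```python +from text_generation import Client + +# Placeholder URL: replace with your own text-generation-inference endpoint +client = Client("https://YOUR_ENDPOINT.endpoints.huggingface.cloud") + +# `max_new_tokens` and `temperature` are assumed to be accepted as keyword +# arguments, mirroring the `Parameters` fields documented below +response = client.generate( + "Why is the sky blue?", + max_new_tokens=32, + temperature=0.5, +) + +print(response.generated_text) + +# Generation metadata is returned in `details` +print(response.details.generated_tokens, response.details.finish_reason) +```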
+ +## Get Started + +### Install + +```shell +pip install text-generation +``` + +### Inference API Usage + +```python +from text_generation import InferenceAPIClient + +client = InferenceAPIClient("bigscience/bloomz") +text = client.generate("Why is the sky blue?").generated_text +print(text) +# ' Rayleigh scattering' + +# Token Streaming +text = "" +for response in client.generate_stream("Why is the sky blue?"): + if not response.token.special: + text += response.token.text + +print(text) +# ' Rayleigh scattering' +``` + +or with the asynchronous client: + +```python +from text_generation import InferenceAPIAsyncClient + +client = InferenceAPIAsyncClient("bigscience/bloomz") +response = await client.generate("Why is the sky blue?") +print(response.generated_text) +# ' Rayleigh scattering' + +# Token Streaming +text = "" +async for response in client.generate_stream("Why is the sky blue?"): + if not response.token.special: + text += response.token.text + +print(text) +# ' Rayleigh scattering' +``` + +Check all currently deployed models on the Huggingface Inference API with `Text Generation` support: + +```python +from text_generation.inference_api import deployed_models + +print(deployed_models()) +``` + +### Hugging Face Inference Endpoint usage + +```python +from text_generation import Client + +endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud" + +client = Client(endpoint_url) +text = client.generate("Why is the sky blue?").generated_text +print(text) +# ' Rayleigh scattering' + +# Token Streaming +text = "" +for response in client.generate_stream("Why is the sky blue?"): + if not response.token.special: + text += response.token.text + +print(text) +# ' Rayleigh scattering' +``` + +or with the asynchronous client: + +```python +from text_generation import AsyncClient + +endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud" + +client = AsyncClient(endpoint_url) +response = await client.generate("Why is the sky blue?") +print(response.generated_text) +# ' Rayleigh scattering' + +# Token Streaming +text = "" +async for response in client.generate_stream("Why is the sky blue?"): + if not response.token.special: + text += response.token.text + +print(text) +# ' Rayleigh scattering' +``` + +### Types + +```python +# enum for grammar type +class GrammarType(Enum): + Json = "json" + Regex = "regex" + + +# Grammar type and value +class Grammar: + # Grammar type + type: GrammarType + # Grammar value + value: Union[str, dict] + +class Parameters: + # Activate logits sampling + do_sample: bool + # Maximum number of generated tokens + max_new_tokens: int + # The parameter for repetition penalty. 1.0 means no penalty. + # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + repetition_penalty: Optional[float] + # The parameter for frequency penalty. 1.0 means no penalty + # Penalize new tokens based on their existing frequency in the text so far, + # decreasing the model's likelihood to repeat the same line verbatim. + frequency_penalty: Optional[float] + # Whether to prepend the prompt to the generated text + return_full_text: bool + # Stop generating tokens if a member of `stop_sequences` is generated + stop: List[str] + # Random sampling seed + seed: Optional[int] + # The value used to module the logits distribution. + temperature: Optional[float] + # The number of highest probability vocabulary tokens to keep for top-k-filtering. 
+ top_k: Optional[int] + # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + # higher are kept for generation. + top_p: Optional[float] + # truncate inputs tokens to the given size + truncate: Optional[int] + # Typical Decoding mass + # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + typical_p: Optional[float] + # Generate best_of sequences and return the one if the highest token logprobs + best_of: Optional[int] + # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + watermark: bool + # Get generation details + details: bool + # Get decoder input token logprobs and ids + decoder_input_details: bool + # Return the N most likely tokens at each step + top_n_tokens: Optional[int] + # grammar to use for generation + grammar: Optional[Grammar] + +class Request: + # Prompt + inputs: str + # Generation parameters + parameters: Optional[Parameters] + # Whether to stream output tokens + stream: bool + +# Decoder input tokens +class InputToken: + # Token ID from the model tokenizer + id: int + # Token text + text: str + # Logprob + # Optional since the logprob of the first token cannot be computed + logprob: Optional[float] + + +# Generated tokens +class Token: + # Token ID from the model tokenizer + id: int + # Token text + text: str + # Logprob + logprob: Optional[float] + # Is the token a special token + # Can be used to ignore tokens when concatenating + special: bool + + +# Generation finish reason +class FinishReason(Enum): + # number of generated tokens == `max_new_tokens` + Length = "length" + # the model generated its end of sequence token + EndOfSequenceToken = "eos_token" + # the model generated a text included in `stop_sequences` + StopSequence = "stop_sequence" + + +# Additional sequences when using the `best_of` parameter +class BestOfSequence: + # Generated text + generated_text: str + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] + # Generated tokens + tokens: List[Token] + # Most likely tokens + top_tokens: Optional[List[List[Token]]] + + +# `generate` details +class Details: + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] + # Generated tokens + tokens: List[Token] + # Most likely tokens + top_tokens: Optional[List[List[Token]]] + # Additional sequences when using the `best_of` parameter + best_of_sequences: Optional[List[BestOfSequence]] + + +# `generate` return value +class Response: + # Generated text + generated_text: str + # Generation details + details: Details + + +# `generate_stream` details +class StreamDetails: + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] + + +# `generate_stream` return value +class StreamResponse: + # Generated token + token: Token + # Most likely tokens + top_tokens: Optional[List[Token]] + # Complete generated text + # Only available when the generation is finished + generated_text: Optional[str] + # Generation 
details + # Only available when the generation is finished + details: Optional[StreamDetails] + +# Inference API currently deployed model +class DeployedModel: + model_id: str + sha: str +``` diff --git a/clients/python/poetry.lock b/clients/python/poetry.lock new file mode 100644 index 00000000..148d9906 --- /dev/null +++ b/clients/python/poetry.lock @@ -0,0 +1,1163 @@ +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. + +[[package]] +name = "aiohttp" +version = "3.8.5" +description = "Async http client/server framework (asyncio)" +optional = false +python-versions = ">=3.6" +files = [ + {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a94159871304770da4dd371f4291b20cac04e8c94f11bdea1c3478e557fbe0d8"}, + {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:13bf85afc99ce6f9ee3567b04501f18f9f8dbbb2ea11ed1a2e079670403a7c84"}, + {file = "aiohttp-3.8.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ce2ac5708501afc4847221a521f7e4b245abf5178cf5ddae9d5b3856ddb2f3a"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96943e5dcc37a6529d18766597c491798b7eb7a61d48878611298afc1fca946c"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ad5c3c4590bb3cc28b4382f031f3783f25ec223557124c68754a2231d989e2b"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c413c633d0512df4dc7fd2373ec06cc6a815b7b6d6c2f208ada7e9e93a5061d"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df72ac063b97837a80d80dec8d54c241af059cc9bb42c4de68bd5b61ceb37caa"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c48c5c0271149cfe467c0ff8eb941279fd6e3f65c9a388c984e0e6cf57538e14"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:368a42363c4d70ab52c2c6420a57f190ed3dfaca6a1b19afda8165ee16416a82"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7607ec3ce4993464368505888af5beb446845a014bc676d349efec0e05085905"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0d21c684808288a98914e5aaf2a7c6a3179d4df11d249799c32d1808e79503b5"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:312fcfbacc7880a8da0ae8b6abc6cc7d752e9caa0051a53d217a650b25e9a691"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ad093e823df03bb3fd37e7dec9d4670c34f9e24aeace76808fc20a507cace825"}, + {file = "aiohttp-3.8.5-cp310-cp310-win32.whl", hash = "sha256:33279701c04351a2914e1100b62b2a7fdb9a25995c4a104259f9a5ead7ed4802"}, + {file = "aiohttp-3.8.5-cp310-cp310-win_amd64.whl", hash = "sha256:6e4a280e4b975a2e7745573e3fc9c9ba0d1194a3738ce1cbaa80626cc9b4f4df"}, + {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ae871a964e1987a943d83d6709d20ec6103ca1eaf52f7e0d36ee1b5bebb8b9b9"}, + {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:461908b2578955045efde733719d62f2b649c404189a09a632d245b445c9c975"}, + {file = "aiohttp-3.8.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:72a860c215e26192379f57cae5ab12b168b75db8271f111019509a1196dfc780"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:cc14be025665dba6202b6a71cfcdb53210cc498e50068bc088076624471f8bb9"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8af740fc2711ad85f1a5c034a435782fbd5b5f8314c9a3ef071424a8158d7f6b"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:841cd8233cbd2111a0ef0a522ce016357c5e3aff8a8ce92bcfa14cef890d698f"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed1c46fb119f1b59304b5ec89f834f07124cd23ae5b74288e364477641060ff"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84f8ae3e09a34f35c18fa57f015cc394bd1389bce02503fb30c394d04ee6b938"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:62360cb771707cb70a6fd114b9871d20d7dd2163a0feafe43fd115cfe4fe845e"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:23fb25a9f0a1ca1f24c0a371523546366bb642397c94ab45ad3aedf2941cec6a"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0ba0d15164eae3d878260d4c4df859bbdc6466e9e6689c344a13334f988bb53"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5d20003b635fc6ae3f96d7260281dfaf1894fc3aa24d1888a9b2628e97c241e5"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0175d745d9e85c40dcc51c8f88c74bfbaef9e7afeeeb9d03c37977270303064c"}, + {file = "aiohttp-3.8.5-cp311-cp311-win32.whl", hash = "sha256:2e1b1e51b0774408f091d268648e3d57f7260c1682e7d3a63cb00d22d71bb945"}, + {file = "aiohttp-3.8.5-cp311-cp311-win_amd64.whl", hash = "sha256:043d2299f6dfdc92f0ac5e995dfc56668e1587cea7f9aa9d8a78a1b6554e5755"}, + {file = "aiohttp-3.8.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cae533195e8122584ec87531d6df000ad07737eaa3c81209e85c928854d2195c"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f21e83f355643c345177a5d1d8079f9f28b5133bcd154193b799d380331d5d3"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7a75ef35f2df54ad55dbf4b73fe1da96f370e51b10c91f08b19603c64004acc"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e2e9839e14dd5308ee773c97115f1e0a1cb1d75cbeeee9f33824fa5144c7634"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44e65da1de4403d0576473e2344828ef9c4c6244d65cf4b75549bb46d40b8dd"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78d847e4cde6ecc19125ccbc9bfac4a7ab37c234dd88fbb3c5c524e8e14da543"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:c7a815258e5895d8900aec4454f38dca9aed71085f227537208057853f9d13f2"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:8b929b9bd7cd7c3939f8bcfffa92fae7480bd1aa425279d51a89327d600c704d"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:5db3a5b833764280ed7618393832e0853e40f3d3e9aa128ac0ba0f8278d08649"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:a0215ce6041d501f3155dc219712bc41252d0ab76474615b9700d63d4d9292af"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:fd1ed388ea7fbed22c4968dd64bab0198de60750a25fe8c0c9d4bef5abe13824"}, + {file = 
"aiohttp-3.8.5-cp36-cp36m-win32.whl", hash = "sha256:6e6783bcc45f397fdebc118d772103d751b54cddf5b60fbcc958382d7dd64f3e"}, + {file = "aiohttp-3.8.5-cp36-cp36m-win_amd64.whl", hash = "sha256:b5411d82cddd212644cf9360879eb5080f0d5f7d809d03262c50dad02f01421a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:01d4c0c874aa4ddfb8098e85d10b5e875a70adc63db91f1ae65a4b04d3344cda"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5980a746d547a6ba173fd5ee85ce9077e72d118758db05d229044b469d9029a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a482e6da906d5e6e653be079b29bc173a48e381600161c9932d89dfae5942ef"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80bd372b8d0715c66c974cf57fe363621a02f359f1ec81cba97366948c7fc873"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1161b345c0a444ebcf46bf0a740ba5dcf50612fd3d0528883fdc0eff578006a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd56db019015b6acfaaf92e1ac40eb8434847d9bf88b4be4efe5bfd260aee692"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:153c2549f6c004d2754cc60603d4668899c9895b8a89397444a9c4efa282aaf4"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4a01951fabc4ce26ab791da5f3f24dca6d9a6f24121746eb19756416ff2d881b"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bfb9162dcf01f615462b995a516ba03e769de0789de1cadc0f916265c257e5d8"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7dde0009408969a43b04c16cbbe252c4f5ef4574ac226bc8815cd7342d2028b6"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4149d34c32f9638f38f544b3977a4c24052042affa895352d3636fa8bffd030a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-win32.whl", hash = "sha256:68c5a82c8779bdfc6367c967a4a1b2aa52cd3595388bf5961a62158ee8a59e22"}, + {file = "aiohttp-3.8.5-cp37-cp37m-win_amd64.whl", hash = "sha256:2cf57fb50be5f52bda004b8893e63b48530ed9f0d6c96c84620dc92fe3cd9b9d"}, + {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:eca4bf3734c541dc4f374ad6010a68ff6c6748f00451707f39857f429ca36ced"}, + {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1274477e4c71ce8cfe6c1ec2f806d57c015ebf84d83373676036e256bc55d690"}, + {file = "aiohttp-3.8.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:28c543e54710d6158fc6f439296c7865b29e0b616629767e685a7185fab4a6b9"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:910bec0c49637d213f5d9877105d26e0c4a4de2f8b1b29405ff37e9fc0ad52b8"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5443910d662db951b2e58eb70b0fbe6b6e2ae613477129a5805d0b66c54b6cb7"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e460be6978fc24e3df83193dc0cc4de46c9909ed92dd47d349a452ef49325b7"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb1558def481d84f03b45888473fc5a1f35747b5f334ef4e7a571bc0dfcb11f8"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:34dd0c107799dcbbf7d48b53be761a013c0adf5571bf50c4ecad5643fe9cfcd0"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aa1990247f02a54185dc0dff92a6904521172a22664c863a03ff64c42f9b5410"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0e584a10f204a617d71d359fe383406305a4b595b333721fa50b867b4a0a1548"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:a3cf433f127efa43fee6b90ea4c6edf6c4a17109d1d037d1a52abec84d8f2e42"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c11f5b099adafb18e65c2c997d57108b5bbeaa9eeee64a84302c0978b1ec948b"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:84de26ddf621d7ac4c975dbea4c945860e08cccde492269db4e1538a6a6f3c35"}, + {file = "aiohttp-3.8.5-cp38-cp38-win32.whl", hash = "sha256:ab88bafedc57dd0aab55fa728ea10c1911f7e4d8b43e1d838a1739f33712921c"}, + {file = "aiohttp-3.8.5-cp38-cp38-win_amd64.whl", hash = "sha256:5798a9aad1879f626589f3df0f8b79b3608a92e9beab10e5fda02c8a2c60db2e"}, + {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a6ce61195c6a19c785df04e71a4537e29eaa2c50fe745b732aa937c0c77169f3"}, + {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:773dd01706d4db536335fcfae6ea2440a70ceb03dd3e7378f3e815b03c97ab51"}, + {file = "aiohttp-3.8.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f83a552443a526ea38d064588613aca983d0ee0038801bc93c0c916428310c28"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f7372f7341fcc16f57b2caded43e81ddd18df53320b6f9f042acad41f8e049a"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea353162f249c8097ea63c2169dd1aa55de1e8fecbe63412a9bc50816e87b761"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5d47ae48db0b2dcf70bc8a3bc72b3de86e2a590fc299fdbbb15af320d2659de"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d827176898a2b0b09694fbd1088c7a31836d1a505c243811c87ae53a3f6273c1"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3562b06567c06439d8b447037bb655ef69786c590b1de86c7ab81efe1c9c15d8"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4e874cbf8caf8959d2adf572a78bba17cb0e9d7e51bb83d86a3697b686a0ab4d"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6809a00deaf3810e38c628e9a33271892f815b853605a936e2e9e5129762356c"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:33776e945d89b29251b33a7e7d006ce86447b2cfd66db5e5ded4e5cd0340585c"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eaeed7abfb5d64c539e2db173f63631455f1196c37d9d8d873fc316470dfbacd"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e91d635961bec2d8f19dfeb41a539eb94bd073f075ca6dae6c8dc0ee89ad6f91"}, + {file = "aiohttp-3.8.5-cp39-cp39-win32.whl", hash = "sha256:00ad4b6f185ec67f3e6562e8a1d2b69660be43070bd0ef6fcec5211154c7df67"}, + {file = "aiohttp-3.8.5-cp39-cp39-win_amd64.whl", hash = "sha256:c0a9034379a37ae42dea7ac1e048352d96286626251862e448933c0f59cbd79c"}, + {file = "aiohttp-3.8.5.tar.gz", hash = "sha256:b9552ec52cc147dbf1944ac7ac98af7602e51ea2dcd076ed194ca3c0d1c7d0bc"}, +] + +[package.dependencies] +aiosignal = ">=1.1.2" +async-timeout = ">=4.0.0a3,<5.0" 
+asynctest = {version = "0.13.0", markers = "python_version < \"3.8\""} +attrs = ">=17.3.0" +charset-normalizer = ">=2.0,<4.0" +frozenlist = ">=1.1.1" +multidict = ">=4.5,<7.0" +typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""} +yarl = ">=1.0,<2.0" + +[package.extras] +speedups = ["Brotli", "aiodns", "cchardet"] + +[[package]] +name = "aiosignal" +version = "1.3.1" +description = "aiosignal: a list of registered asynchronous callbacks" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, + {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, +] + +[package.dependencies] +frozenlist = ">=1.1.0" + +[[package]] +name = "annotated-types" +version = "0.5.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.7" +files = [ + {file = "annotated_types-0.5.0-py3-none-any.whl", hash = "sha256:58da39888f92c276ad970249761ebea80ba544b77acddaa1a4d6cf78287d45fd"}, + {file = "annotated_types-0.5.0.tar.gz", hash = "sha256:47cdc3490d9ac1506ce92c7aaa76c579dc3509ff11e098fc867e5130ab7be802"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} + +[[package]] +name = "async-timeout" +version = "4.0.3" +description = "Timeout context manager for asyncio programs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, +] + +[package.dependencies] +typing-extensions = {version = ">=3.6.5", markers = "python_version < \"3.8\""} + +[[package]] +name = "asynctest" +version = "0.13.0" +description = "Enhance the standard unittest package with features for testing asyncio libraries" +optional = false +python-versions = ">=3.5" +files = [ + {file = "asynctest-0.13.0-py3-none-any.whl", hash = "sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676"}, + {file = "asynctest-0.13.0.tar.gz", hash = "sha256:c27862842d15d83e6a34eb0b2866c323880eb3a75e4485b079ea11748fd77fac"}, +] + +[[package]] +name = "atomicwrites" +version = "1.4.1" +description = "Atomic file writes." 
+optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "atomicwrites-1.4.1.tar.gz", hash = "sha256:81b2c9071a49367a7f770170e5eec8cb66567cfbbc8c73d20ce5ca4a8d71cf11"}, +] + +[[package]] +name = "attrs" +version = "23.1.0" +description = "Classes Without Boilerplate" +optional = false +python-versions = ">=3.7" +files = [ + {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, + {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"}, +] + +[package.dependencies] +importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} + +[package.extras] +cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] +dev = ["attrs[docs,tests]", "pre-commit"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] +tests = ["attrs[tests-no-zope]", "zope-interface"] +tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] + +[[package]] +name = "certifi" +version = "2023.7.22" +description = "Python package for providing Mozilla's CA Bundle." +optional = false +python-versions = ">=3.6" +files = [ + {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, + {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"}, +] + +[[package]] +name = "charset-normalizer" +version = "3.2.0" +description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "charset-normalizer-3.2.0.tar.gz", hash = "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3"}, + {file = 
"charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-win32.whl", hash = "sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-win32.whl", hash = "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346"}, + 
{file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-win32.whl", hash = "sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46"}, + {file = 
"charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-win32.whl", hash = "sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-win32.whl", hash = "sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80"}, + {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"}, +] + +[[package]] +name = "colorama" +version = "0.4.6" 
+description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "coverage" +version = "7.2.7" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "coverage-7.2.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d39b5b4f2a66ccae8b7263ac3c8170994b65266797fb96cbbfd3fb5b23921db8"}, + {file = "coverage-7.2.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d040ef7c9859bb11dfeb056ff5b3872436e3b5e401817d87a31e1750b9ae2fb"}, + {file = "coverage-7.2.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba90a9563ba44a72fda2e85302c3abc71c5589cea608ca16c22b9804262aaeb6"}, + {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7d9405291c6928619403db1d10bd07888888ec1abcbd9748fdaa971d7d661b2"}, + {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31563e97dae5598556600466ad9beea39fb04e0229e61c12eaa206e0aa202063"}, + {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ebba1cd308ef115925421d3e6a586e655ca5a77b5bf41e02eb0e4562a111f2d1"}, + {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cb017fd1b2603ef59e374ba2063f593abe0fc45f2ad9abdde5b4d83bd922a353"}, + {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62a5c7dad11015c66fbb9d881bc4caa5b12f16292f857842d9d1871595f4495"}, + {file = "coverage-7.2.7-cp310-cp310-win32.whl", hash = "sha256:ee57190f24fba796e36bb6d3aa8a8783c643d8fa9760c89f7a98ab5455fbf818"}, + {file = "coverage-7.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:f75f7168ab25dd93110c8a8117a22450c19976afbc44234cbf71481094c1b850"}, + {file = "coverage-7.2.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06a9a2be0b5b576c3f18f1a241f0473575c4a26021b52b2a85263a00f034d51f"}, + {file = "coverage-7.2.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5baa06420f837184130752b7c5ea0808762083bf3487b5038d68b012e5937dbe"}, + {file = "coverage-7.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdec9e8cbf13a5bf63290fc6013d216a4c7232efb51548594ca3631a7f13c3a3"}, + {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52edc1a60c0d34afa421c9c37078817b2e67a392cab17d97283b64c5833f427f"}, + {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63426706118b7f5cf6bb6c895dc215d8a418d5952544042c8a2d9fe87fcf09cb"}, + {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:afb17f84d56068a7c29f5fa37bfd38d5aba69e3304af08ee94da8ed5b0865833"}, + {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:48c19d2159d433ccc99e729ceae7d5293fbffa0bdb94952d3579983d1c8c9d97"}, + {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e1f928eaf5469c11e886fe0885ad2bf1ec606434e79842a879277895a50942a"}, + {file = "coverage-7.2.7-cp311-cp311-win32.whl", hash = 
"sha256:33d6d3ea29d5b3a1a632b3c4e4f4ecae24ef170b0b9ee493883f2df10039959a"}, + {file = "coverage-7.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:5b7540161790b2f28143191f5f8ec02fb132660ff175b7747b95dcb77ac26562"}, + {file = "coverage-7.2.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f2f67fe12b22cd130d34d0ef79206061bfb5eda52feb6ce0dba0644e20a03cf4"}, + {file = "coverage-7.2.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a342242fe22407f3c17f4b499276a02b01e80f861f1682ad1d95b04018e0c0d4"}, + {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:171717c7cb6b453aebac9a2ef603699da237f341b38eebfee9be75d27dc38e01"}, + {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49969a9f7ffa086d973d91cec8d2e31080436ef0fb4a359cae927e742abfaaa6"}, + {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b46517c02ccd08092f4fa99f24c3b83d8f92f739b4657b0f146246a0ca6a831d"}, + {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a3d33a6b3eae87ceaefa91ffdc130b5e8536182cd6dfdbfc1aa56b46ff8c86de"}, + {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:976b9c42fb2a43ebf304fa7d4a310e5f16cc99992f33eced91ef6f908bd8f33d"}, + {file = "coverage-7.2.7-cp312-cp312-win32.whl", hash = "sha256:8de8bb0e5ad103888d65abef8bca41ab93721647590a3f740100cd65c3b00511"}, + {file = "coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3"}, + {file = "coverage-7.2.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:58c2ccc2f00ecb51253cbe5d8d7122a34590fac9646a960d1430d5b15321d95f"}, + {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d22656368f0e6189e24722214ed8d66b8022db19d182927b9a248a2a8a2f67eb"}, + {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a895fcc7b15c3fc72beb43cdcbdf0ddb7d2ebc959edac9cef390b0d14f39f8a9"}, + {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84606b74eb7de6ff581a7915e2dab7a28a0517fbe1c9239eb227e1354064dcd"}, + {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0a5f9e1dbd7fbe30196578ca36f3fba75376fb99888c395c5880b355e2875f8a"}, + {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:419bfd2caae268623dd469eff96d510a920c90928b60f2073d79f8fe2bbc5959"}, + {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2aee274c46590717f38ae5e4650988d1af340fe06167546cc32fe2f58ed05b02"}, + {file = "coverage-7.2.7-cp37-cp37m-win32.whl", hash = "sha256:61b9a528fb348373c433e8966535074b802c7a5d7f23c4f421e6c6e2f1697a6f"}, + {file = "coverage-7.2.7-cp37-cp37m-win_amd64.whl", hash = "sha256:b1c546aca0ca4d028901d825015dc8e4d56aac4b541877690eb76490f1dc8ed0"}, + {file = "coverage-7.2.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:54b896376ab563bd38453cecb813c295cf347cf5906e8b41d340b0321a5433e5"}, + {file = "coverage-7.2.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3d376df58cc111dc8e21e3b6e24606b5bb5dee6024f46a5abca99124b2229ef5"}, + {file = "coverage-7.2.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e330fc79bd7207e46c7d7fd2bb4af2963f5f635703925543a70b99574b0fea9"}, + {file = 
"coverage-7.2.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e9d683426464e4a252bf70c3498756055016f99ddaec3774bf368e76bbe02b6"}, + {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d13c64ee2d33eccf7437961b6ea7ad8673e2be040b4f7fd4fd4d4d28d9ccb1e"}, + {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b7aa5f8a41217360e600da646004f878250a0d6738bcdc11a0a39928d7dc2050"}, + {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fa03bce9bfbeeef9f3b160a8bed39a221d82308b4152b27d82d8daa7041fee5"}, + {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:245167dd26180ab4c91d5e1496a30be4cd721a5cf2abf52974f965f10f11419f"}, + {file = "coverage-7.2.7-cp38-cp38-win32.whl", hash = "sha256:d2c2db7fd82e9b72937969bceac4d6ca89660db0a0967614ce2481e81a0b771e"}, + {file = "coverage-7.2.7-cp38-cp38-win_amd64.whl", hash = "sha256:2e07b54284e381531c87f785f613b833569c14ecacdcb85d56b25c4622c16c3c"}, + {file = "coverage-7.2.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:537891ae8ce59ef63d0123f7ac9e2ae0fc8b72c7ccbe5296fec45fd68967b6c9"}, + {file = "coverage-7.2.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:06fb182e69f33f6cd1d39a6c597294cff3143554b64b9825d1dc69d18cc2fff2"}, + {file = "coverage-7.2.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:201e7389591af40950a6480bd9edfa8ed04346ff80002cec1a66cac4549c1ad7"}, + {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6951407391b639504e3b3be51b7ba5f3528adbf1a8ac3302b687ecababf929e"}, + {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f48351d66575f535669306aa7d6d6f71bc43372473b54a832222803eb956fd1"}, + {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b29019c76039dc3c0fd815c41392a044ce555d9bcdd38b0fb60fb4cd8e475ba9"}, + {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:81c13a1fc7468c40f13420732805a4c38a105d89848b7c10af65a90beff25250"}, + {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:975d70ab7e3c80a3fe86001d8751f6778905ec723f5b110aed1e450da9d4b7f2"}, + {file = "coverage-7.2.7-cp39-cp39-win32.whl", hash = "sha256:7ee7d9d4822c8acc74a5e26c50604dff824710bc8de424904c0982e25c39c6cb"}, + {file = "coverage-7.2.7-cp39-cp39-win_amd64.whl", hash = "sha256:eb393e5ebc85245347950143969b241d08b52b88a3dc39479822e073a1a8eb27"}, + {file = "coverage-7.2.7-pp37.pp38.pp39-none-any.whl", hash = "sha256:b7b4c971f05e6ae490fef852c218b0e79d4e52f79ef0c8475566584a8fb3e01d"}, + {file = "coverage-7.2.7.tar.gz", hash = "sha256:924d94291ca674905fe9481f12294eb11f2d3d3fd1adb20314ba89e94f44ed59"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli"] + +[[package]] +name = "filelock" +version = "3.12.2" +description = "A platform independent file lock." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"}, + {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"}, +] + +[package.extras] +docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] + +[[package]] +name = "frozenlist" +version = "1.3.3" +description = "A list-like structure which implements collections.abc.MutableSequence" +optional = false +python-versions = ">=3.7" +files = [ + {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff8bf625fe85e119553b5383ba0fb6aa3d0ec2ae980295aaefa552374926b3f4"}, + {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dfbac4c2dfcc082fcf8d942d1e49b6aa0766c19d3358bd86e2000bf0fa4a9cf0"}, + {file = "frozenlist-1.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b1c63e8d377d039ac769cd0926558bb7068a1f7abb0f003e3717ee003ad85530"}, + {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fdfc24dcfce5b48109867c13b4cb15e4660e7bd7661741a391f821f23dfdca7"}, + {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c926450857408e42f0bbc295e84395722ce74bae69a3b2aa2a65fe22cb14b99"}, + {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1841e200fdafc3d51f974d9d377c079a0694a8f06de2e67b48150328d66d5483"}, + {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f470c92737afa7d4c3aacc001e335062d582053d4dbe73cda126f2d7031068dd"}, + {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:783263a4eaad7c49983fe4b2e7b53fa9770c136c270d2d4bbb6d2192bf4d9caf"}, + {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:924620eef691990dfb56dc4709f280f40baee568c794b5c1885800c3ecc69816"}, + {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ae4dc05c465a08a866b7a1baf360747078b362e6a6dbeb0c57f234db0ef88ae0"}, + {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:bed331fe18f58d844d39ceb398b77d6ac0b010d571cba8267c2e7165806b00ce"}, + {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:02c9ac843e3390826a265e331105efeab489ffaf4dd86384595ee8ce6d35ae7f"}, + {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9545a33965d0d377b0bc823dcabf26980e77f1b6a7caa368a365a9497fb09420"}, + {file = "frozenlist-1.3.3-cp310-cp310-win32.whl", hash = "sha256:d5cd3ab21acbdb414bb6c31958d7b06b85eeb40f66463c264a9b343a4e238642"}, + {file = "frozenlist-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:b756072364347cb6aa5b60f9bc18e94b2f79632de3b0190253ad770c5df17db1"}, + {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b4395e2f8d83fbe0c627b2b696acce67868793d7d9750e90e39592b3626691b7"}, + {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14143ae966a6229350021384870458e4777d1eae4c28d1a7aa47f24d030e6678"}, + {file = 
"frozenlist-1.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5d8860749e813a6f65bad8285a0520607c9500caa23fea6ee407e63debcdbef6"}, + {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23d16d9f477bb55b6154654e0e74557040575d9d19fe78a161bd33d7d76808e8"}, + {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb82dbba47a8318e75f679690190c10a5e1f447fbf9df41cbc4c3afd726d88cb"}, + {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9309869032abb23d196cb4e4db574232abe8b8be1339026f489eeb34a4acfd91"}, + {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a97b4fe50b5890d36300820abd305694cb865ddb7885049587a5678215782a6b"}, + {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c188512b43542b1e91cadc3c6c915a82a5eb95929134faf7fd109f14f9892ce4"}, + {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:303e04d422e9b911a09ad499b0368dc551e8c3cd15293c99160c7f1f07b59a48"}, + {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0771aed7f596c7d73444c847a1c16288937ef988dc04fb9f7be4b2aa91db609d"}, + {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:66080ec69883597e4d026f2f71a231a1ee9887835902dbe6b6467d5a89216cf6"}, + {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:41fe21dc74ad3a779c3d73a2786bdf622ea81234bdd4faf90b8b03cad0c2c0b4"}, + {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f20380df709d91525e4bee04746ba612a4df0972c1b8f8e1e8af997e678c7b81"}, + {file = "frozenlist-1.3.3-cp311-cp311-win32.whl", hash = "sha256:f30f1928162e189091cf4d9da2eac617bfe78ef907a761614ff577ef4edfb3c8"}, + {file = "frozenlist-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a6394d7dadd3cfe3f4b3b186e54d5d8504d44f2d58dcc89d693698e8b7132b32"}, + {file = "frozenlist-1.3.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8df3de3a9ab8325f94f646609a66cbeeede263910c5c0de0101079ad541af332"}, + {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0693c609e9742c66ba4870bcee1ad5ff35462d5ffec18710b4ac89337ff16e27"}, + {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd4210baef299717db0a600d7a3cac81d46ef0e007f88c9335db79f8979c0d3d"}, + {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:394c9c242113bfb4b9aa36e2b80a05ffa163a30691c7b5a29eba82e937895d5e"}, + {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6327eb8e419f7d9c38f333cde41b9ae348bec26d840927332f17e887a8dcb70d"}, + {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e24900aa13212e75e5b366cb9065e78bbf3893d4baab6052d1aca10d46d944c"}, + {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3843f84a6c465a36559161e6c59dce2f2ac10943040c2fd021cfb70d58c4ad56"}, + {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:84610c1502b2461255b4c9b7d5e9c48052601a8957cd0aea6ec7a7a1e1fb9420"}, + {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = 
"sha256:c21b9aa40e08e4f63a2f92ff3748e6b6c84d717d033c7b3438dd3123ee18f70e"}, + {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:efce6ae830831ab6a22b9b4091d411698145cb9b8fc869e1397ccf4b4b6455cb"}, + {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:40de71985e9042ca00b7953c4f41eabc3dc514a2d1ff534027f091bc74416401"}, + {file = "frozenlist-1.3.3-cp37-cp37m-win32.whl", hash = "sha256:180c00c66bde6146a860cbb81b54ee0df350d2daf13ca85b275123bbf85de18a"}, + {file = "frozenlist-1.3.3-cp37-cp37m-win_amd64.whl", hash = "sha256:9bbbcedd75acdfecf2159663b87f1bb5cfc80e7cd99f7ddd9d66eb98b14a8411"}, + {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:034a5c08d36649591be1cbb10e09da9f531034acfe29275fc5454a3b101ce41a"}, + {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ba64dc2b3b7b158c6660d49cdb1d872d1d0bf4e42043ad8d5006099479a194e5"}, + {file = "frozenlist-1.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:47df36a9fe24054b950bbc2db630d508cca3aa27ed0566c0baf661225e52c18e"}, + {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:008a054b75d77c995ea26629ab3a0c0d7281341f2fa7e1e85fa6153ae29ae99c"}, + {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:841ea19b43d438a80b4de62ac6ab21cfe6827bb8a9dc62b896acc88eaf9cecba"}, + {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e235688f42b36be2b6b06fc37ac2126a73b75fb8d6bc66dd632aa35286238703"}, + {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca713d4af15bae6e5d79b15c10c8522859a9a89d3b361a50b817c98c2fb402a2"}, + {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ac5995f2b408017b0be26d4a1d7c61bce106ff3d9e3324374d66b5964325448"}, + {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a4ae8135b11652b08a8baf07631d3ebfe65a4c87909dbef5fa0cdde440444ee4"}, + {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4ea42116ceb6bb16dbb7d526e242cb6747b08b7710d9782aa3d6732bd8d27649"}, + {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:810860bb4bdce7557bc0febb84bbd88198b9dbc2022d8eebe5b3590b2ad6c842"}, + {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:ee78feb9d293c323b59a6f2dd441b63339a30edf35abcb51187d2fc26e696d13"}, + {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0af2e7c87d35b38732e810befb9d797a99279cbb85374d42ea61c1e9d23094b3"}, + {file = "frozenlist-1.3.3-cp38-cp38-win32.whl", hash = "sha256:899c5e1928eec13fd6f6d8dc51be23f0d09c5281e40d9cf4273d188d9feeaf9b"}, + {file = "frozenlist-1.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:7f44e24fa70f6fbc74aeec3e971f60a14dde85da364aa87f15d1be94ae75aeef"}, + {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:2b07ae0c1edaa0a36339ec6cce700f51b14a3fc6545fdd32930d2c83917332cf"}, + {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ebb86518203e12e96af765ee89034a1dbb0c3c65052d1b0c19bbbd6af8a145e1"}, + {file = "frozenlist-1.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5cf820485f1b4c91e0417ea0afd41ce5cf5965011b3c22c400f6d144296ccbc0"}, + {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:5c11e43016b9024240212d2a65043b70ed8dfd3b52678a1271972702d990ac6d"}, + {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8fa3c6e3305aa1146b59a09b32b2e04074945ffcfb2f0931836d103a2c38f936"}, + {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:352bd4c8c72d508778cf05ab491f6ef36149f4d0cb3c56b1b4302852255d05d5"}, + {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:65a5e4d3aa679610ac6e3569e865425b23b372277f89b5ef06cf2cdaf1ebf22b"}, + {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e2c1185858d7e10ff045c496bbf90ae752c28b365fef2c09cf0fa309291669"}, + {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f163d2fd041c630fed01bc48d28c3ed4a3b003c00acd396900e11ee5316b56bb"}, + {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:05cdb16d09a0832eedf770cb7bd1fe57d8cf4eaf5aced29c4e41e3f20b30a784"}, + {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:8bae29d60768bfa8fb92244b74502b18fae55a80eac13c88eb0b496d4268fd2d"}, + {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eedab4c310c0299961ac285591acd53dc6723a1ebd90a57207c71f6e0c2153ab"}, + {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3bbdf44855ed8f0fbcd102ef05ec3012d6a4fd7c7562403f76ce6a52aeffb2b1"}, + {file = "frozenlist-1.3.3-cp39-cp39-win32.whl", hash = "sha256:efa568b885bca461f7c7b9e032655c0c143d305bf01c30caf6db2854a4532b38"}, + {file = "frozenlist-1.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:cfe33efc9cb900a4c46f91a5ceba26d6df370ffddd9ca386eb1d4f0ad97b9ea9"}, + {file = "frozenlist-1.3.3.tar.gz", hash = "sha256:58bcc55721e8a90b88332d6cd441261ebb22342e238296bb330968952fbb3a6a"}, +] + +[[package]] +name = "fsspec" +version = "2023.1.0" +description = "File-system specification" +optional = false +python-versions = ">=3.7" +files = [ + {file = "fsspec-2023.1.0-py3-none-any.whl", hash = "sha256:b833e2e541e9e8cde0ab549414187871243177feb3d344f9d27b25a93f5d8139"}, + {file = "fsspec-2023.1.0.tar.gz", hash = "sha256:fbae7f20ff801eb5f7d0bedf81f25c787c0dfac5e982d98fa3884a9cde2b5411"}, +] + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +entrypoints = ["importlib-metadata"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +tqdm = ["tqdm"] + +[[package]] +name = "huggingface-hub" +version = "0.16.4" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "huggingface_hub-0.16.4-py3-none-any.whl", hash = "sha256:0d3df29932f334fead024afc7cb4cc5149d955238b8b5e42dcf9740d6995a349"}, + {file = "huggingface_hub-0.16.4.tar.gz", hash = "sha256:608c7d4f3d368b326d1747f91523dbd1f692871e8e2e7a4750314a2dd8b63e14"}, +] + +[package.dependencies] +filelock = "*" +fsspec = "*" +importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} +packaging = ">=20.9" 
+pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +inference = ["aiohttp", "pydantic"] +quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"] +tensorflow = ["graphviz", "pydot", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["torch"] +typing = ["pydantic", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] + +[[package]] +name = "idna" +version = "3.4" +description = "Internationalized Domain Names in Applications (IDNA)" +optional = false +python-versions = ">=3.5" +files = [ + {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, + {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, +] + +[[package]] +name = "importlib-metadata" +version = "6.7.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "importlib_metadata-6.7.0-py3-none-any.whl", hash = "sha256:cb52082e659e97afc5dac71e79de97d8681de3aa07ff18578330904a9d18e5b5"}, + {file = "importlib_metadata-6.7.0.tar.gz", hash = "sha256:1aaf550d4f73e5d6783e7acb77aec43d49da8017410afae93822cc9cca98c4d4"}, +] + +[package.dependencies] +typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""} +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "multidict" +version = "6.0.4" +description = "multidict implementation" +optional = false +python-versions = ">=3.7" +files = [ + {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = 
"sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"}, + {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeb6dcc05e911516ae3d1f207d4b0520d07f54484c49dfc294d6e7d63b734171"}, + {file = "multidict-6.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d6d635d5209b82a3492508cf5b365f3446afb65ae7ebd755e70e18f287b0adf7"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c048099e4c9e9d615545e2001d3d8a4380bd403e1a0578734e0d31703d1b0c0b"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea20853c6dbbb53ed34cb4d080382169b6f4554d394015f1bef35e881bf83547"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16d232d4e5396c2efbbf4f6d4df89bfa905eb0d4dc5b3549d872ab898451f569"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c63aaa167f6c6b04ef2c85704e93af16c11d20de1d133e39de6a0e84582a93"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64bdf1086b6043bf519869678f5f2757f473dee970d7abf6da91ec00acb9cb98"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:43644e38f42e3af682690876cff722d301ac585c5b9e1eacc013b7a3f7b696a0"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7582a1d1030e15422262de9f58711774e02fa80df0d1578995c76214f6954988"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ddff9c4e225a63a5afab9dd15590432c22e8057e1a9a13d28ed128ecf047bbdc"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ee2a1ece51b9b9e7752e742cfb661d2a29e7bcdba2d27e66e28a99f1890e4fa0"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2e4369eb3d47d2034032a26c7a80fcb21a2cb22e1173d761a162f11e562caa5"}, + {file = "multidict-6.0.4-cp310-cp310-win32.whl", hash = "sha256:574b7eae1ab267e5f8285f0fe881f17efe4b98c39a40858247720935b893bba8"}, + {file = "multidict-6.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:4dcbb0906e38440fa3e325df2359ac6cb043df8e58c965bb45f4e406ecb162cc"}, + {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0dfad7a5a1e39c53ed00d2dd0c2e36aed4650936dc18fd9a1826a5ae1cad6f03"}, + {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:64da238a09d6039e3bd39bb3aee9c21a5e34f28bfa5aa22518581f910ff94af3"}, + {file = "multidict-6.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff959bee35038c4624250473988b24f846cbeb2c6639de3602c073f10410ceba"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01a3a55bd90018c9c080fbb0b9f4891db37d148a0a18722b42f94694f8b6d4c9"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5cb09abb18c1ea940fb99360ea0396f34d46566f157122c92dfa069d3e0e982"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666daae833559deb2d609afa4490b85830ab0dfca811a98b70a205621a6109fe"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11bdf3f5e1518b24530b8241529d2050014c884cf18b6fc69c0c2b30ca248710"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:7d18748f2d30f94f498e852c67d61261c643b349b9d2a581131725595c45ec6c"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:458f37be2d9e4c95e2d8866a851663cbc76e865b78395090786f6cd9b3bbf4f4"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b1a2eeedcead3a41694130495593a559a668f382eee0727352b9a41e1c45759a"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7d6ae9d593ef8641544d6263c7fa6408cc90370c8cb2bbb65f8d43e5b0351d9c"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5979b5632c3e3534e42ca6ff856bb24b2e3071b37861c2c727ce220d80eee9ed"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dcfe792765fab89c365123c81046ad4103fcabbc4f56d1c1997e6715e8015461"}, + {file = "multidict-6.0.4-cp311-cp311-win32.whl", hash = "sha256:3601a3cece3819534b11d4efc1eb76047488fddd0c85a3948099d5da4d504636"}, + {file = "multidict-6.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:81a4f0b34bd92df3da93315c6a59034df95866014ac08535fc819f043bfd51f0"}, + {file = "multidict-6.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:67040058f37a2a51ed8ea8f6b0e6ee5bd78ca67f169ce6122f3e2ec80dfe9b78"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:853888594621e6604c978ce2a0444a1e6e70c8d253ab65ba11657659dcc9100f"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:39ff62e7d0f26c248b15e364517a72932a611a9b75f35b45be078d81bdb86603"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af048912e045a2dc732847d33821a9d84ba553f5c5f028adbd364dd4765092ac"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e8b901e607795ec06c9e42530788c45ac21ef3aaa11dbd0c69de543bfb79a9"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62501642008a8b9871ddfccbf83e4222cf8ac0d5aeedf73da36153ef2ec222d2"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:99b76c052e9f1bc0721f7541e5e8c05db3941eb9ebe7b8553c625ef88d6eefde"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:509eac6cf09c794aa27bcacfd4d62c885cce62bef7b2c3e8b2e49d365b5003fe"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:21a12c4eb6ddc9952c415f24eef97e3e55ba3af61f67c7bc388dcdec1404a067"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:5cad9430ab3e2e4fa4a2ef4450f548768400a2ac635841bc2a56a2052cdbeb87"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab55edc2e84460694295f401215f4a58597f8f7c9466faec545093045476327d"}, + {file = "multidict-6.0.4-cp37-cp37m-win32.whl", hash = "sha256:5a4dcf02b908c3b8b17a45fb0f15b695bf117a67b76b7ad18b73cf8e92608775"}, + {file = "multidict-6.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6ed5f161328b7df384d71b07317f4d8656434e34591f20552c7bcef27b0ab88e"}, + {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5fc1b16f586f049820c5c5b17bb4ee7583092fa0d1c4e28b5239181ff9532e0c"}, + {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1502e24330eb681bdaa3eb70d6358e818e8e8f908a22a1851dfd4e15bc2f8161"}, + {file = "multidict-6.0.4-cp38-cp38-macosx_11_0_arm64.whl", hash = 
"sha256:b692f419760c0e65d060959df05f2a531945af31fda0c8a3b3195d4efd06de11"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45e1ecb0379bfaab5eef059f50115b54571acfbe422a14f668fc8c27ba410e7e"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddd3915998d93fbcd2566ddf9cf62cdb35c9e093075f862935573d265cf8f65d"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59d43b61c59d82f2effb39a93c48b845efe23a3852d201ed2d24ba830d0b4cf2"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc8e1d0c705233c5dd0c5e6460fbad7827d5d36f310a0fadfd45cc3029762258"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6aa0418fcc838522256761b3415822626f866758ee0bc6632c9486b179d0b52"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6748717bb10339c4760c1e63da040f5f29f5ed6e59d76daee30305894069a660"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4d1a3d7ef5e96b1c9e92f973e43aa5e5b96c659c9bc3124acbbd81b0b9c8a951"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4372381634485bec7e46718edc71528024fcdc6f835baefe517b34a33c731d60"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:fc35cb4676846ef752816d5be2193a1e8367b4c1397b74a565a9d0389c433a1d"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4b9d9e4e2b37daddb5c23ea33a3417901fa7c7b3dee2d855f63ee67a0b21e5b1"}, + {file = "multidict-6.0.4-cp38-cp38-win32.whl", hash = "sha256:e41b7e2b59679edfa309e8db64fdf22399eec4b0b24694e1b2104fb789207779"}, + {file = "multidict-6.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:d6c254ba6e45d8e72739281ebc46ea5eb5f101234f3ce171f0e9f5cc86991480"}, + {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:16ab77bbeb596e14212e7bab8429f24c1579234a3a462105cda4a66904998664"}, + {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc779e9e6f7fda81b3f9aa58e3a6091d49ad528b11ed19f6621408806204ad35"}, + {file = "multidict-6.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ceef517eca3e03c1cceb22030a3e39cb399ac86bff4e426d4fc6ae49052cc60"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:281af09f488903fde97923c7744bb001a9b23b039a909460d0f14edc7bf59706"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52f2dffc8acaba9a2f27174c41c9e57f60b907bb9f096b36b1a1f3be71c6284d"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b41156839806aecb3641f3208c0dafd3ac7775b9c4c422d82ee2a45c34ba81ca"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3fc56f88cc98ef8139255cf8cd63eb2c586531e43310ff859d6bb3a6b51f1"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8316a77808c501004802f9beebde51c9f857054a0c871bd6da8280e718444449"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f70b98cd94886b49d91170ef23ec5c0e8ebb6f242d734ed7ed677b24d50c82cf"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_i686.whl", hash = 
"sha256:bf6774e60d67a9efe02b3616fee22441d86fab4c6d335f9d2051d19d90a40063"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:e69924bfcdda39b722ef4d9aa762b2dd38e4632b3641b1d9a57ca9cd18f2f83a"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:6b181d8c23da913d4ff585afd1155a0e1194c0b50c54fcfe286f70cdaf2b7176"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52509b5be062d9eafc8170e53026fbc54cf3b32759a23d07fd935fb04fc22d95"}, + {file = "multidict-6.0.4-cp39-cp39-win32.whl", hash = "sha256:27c523fbfbdfd19c6867af7346332b62b586eed663887392cff78d614f9ec313"}, + {file = "multidict-6.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:33029f5734336aa0d4c0384525da0387ef89148dc7191aae00ca5fb23d7aafc2"}, + {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, +] + +[[package]] +name = "packaging" +version = "23.1" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, + {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, +] + +[[package]] +name = "pluggy" +version = "1.2.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"}, + {file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"}, +] + +[package.dependencies] +importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "py" +version = "1.11.0" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, + {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, +] + +[[package]] +name = "pydantic" +version = "2.5.3" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic-2.5.3-py3-none-any.whl", hash = "sha256:d0caf5954bee831b6bfe7e338c32b9e30c85dfe080c843680783ac2b631673b4"}, + {file = "pydantic-2.5.3.tar.gz", hash = "sha256:b3ef57c62535b0941697cce638c08900d87fcb67e29cfa99e8a68f747f393f7a"}, +] + +[package.dependencies] +annotated-types = ">=0.4.0" +importlib-metadata = {version = "*", markers = "python_version == \"3.7\""} +pydantic-core = "2.14.6" +typing-extensions = ">=4.6.1" + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.14.6" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic_core-2.14.6-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:72f9a942d739f09cd42fffe5dc759928217649f070056f03c70df14f5770acf9"}, + {file = "pydantic_core-2.14.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6a31d98c0d69776c2576dda4b77b8e0c69ad08e8b539c25c7d0ca0dc19a50d6c"}, + {file = 
"pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5aa90562bc079c6c290f0512b21768967f9968e4cfea84ea4ff5af5d917016e4"}, + {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:370ffecb5316ed23b667d99ce4debe53ea664b99cc37bfa2af47bc769056d534"}, + {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f85f3843bdb1fe80e8c206fe6eed7a1caeae897e496542cee499c374a85c6e08"}, + {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9862bf828112e19685b76ca499b379338fd4c5c269d897e218b2ae8fcb80139d"}, + {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:036137b5ad0cb0004c75b579445a1efccd072387a36c7f217bb8efd1afbe5245"}, + {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:92879bce89f91f4b2416eba4429c7b5ca22c45ef4a499c39f0c5c69257522c7c"}, + {file = "pydantic_core-2.14.6-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0c08de15d50fa190d577e8591f0329a643eeaed696d7771760295998aca6bc66"}, + {file = "pydantic_core-2.14.6-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:36099c69f6b14fc2c49d7996cbf4f87ec4f0e66d1c74aa05228583225a07b590"}, + {file = "pydantic_core-2.14.6-cp310-none-win32.whl", hash = "sha256:7be719e4d2ae6c314f72844ba9d69e38dff342bc360379f7c8537c48e23034b7"}, + {file = "pydantic_core-2.14.6-cp310-none-win_amd64.whl", hash = "sha256:36fa402dcdc8ea7f1b0ddcf0df4254cc6b2e08f8cd80e7010d4c4ae6e86b2a87"}, + {file = "pydantic_core-2.14.6-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:dea7fcd62915fb150cdc373212141a30037e11b761fbced340e9db3379b892d4"}, + {file = "pydantic_core-2.14.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ffff855100bc066ff2cd3aa4a60bc9534661816b110f0243e59503ec2df38421"}, + {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b027c86c66b8627eb90e57aee1f526df77dc6d8b354ec498be9a757d513b92b"}, + {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:00b1087dabcee0b0ffd104f9f53d7d3eaddfaa314cdd6726143af6bc713aa27e"}, + {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:75ec284328b60a4e91010c1acade0c30584f28a1f345bc8f72fe8b9e46ec6a96"}, + {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e1f4744eea1501404b20b0ac059ff7e3f96a97d3e3f48ce27a139e053bb370b"}, + {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2602177668f89b38b9f84b7b3435d0a72511ddef45dc14446811759b82235a1"}, + {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6c8edaea3089bf908dd27da8f5d9e395c5b4dc092dbcce9b65e7156099b4b937"}, + {file = "pydantic_core-2.14.6-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:478e9e7b360dfec451daafe286998d4a1eeaecf6d69c427b834ae771cad4b622"}, + {file = "pydantic_core-2.14.6-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b6ca36c12a5120bad343eef193cc0122928c5c7466121da7c20f41160ba00ba2"}, + {file = "pydantic_core-2.14.6-cp311-none-win32.whl", hash = "sha256:2b8719037e570639e6b665a4050add43134d80b687288ba3ade18b22bbb29dd2"}, + {file = "pydantic_core-2.14.6-cp311-none-win_amd64.whl", hash = 
"sha256:78ee52ecc088c61cce32b2d30a826f929e1708f7b9247dc3b921aec367dc1b23"}, + {file = "pydantic_core-2.14.6-cp311-none-win_arm64.whl", hash = "sha256:a19b794f8fe6569472ff77602437ec4430f9b2b9ec7a1105cfd2232f9ba355e6"}, + {file = "pydantic_core-2.14.6-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:667aa2eac9cd0700af1ddb38b7b1ef246d8cf94c85637cbb03d7757ca4c3fdec"}, + {file = "pydantic_core-2.14.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cdee837710ef6b56ebd20245b83799fce40b265b3b406e51e8ccc5b85b9099b7"}, + {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c5bcf3414367e29f83fd66f7de64509a8fd2368b1edf4351e862910727d3e51"}, + {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:26a92ae76f75d1915806b77cf459811e772d8f71fd1e4339c99750f0e7f6324f"}, + {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a983cca5ed1dd9a35e9e42ebf9f278d344603bfcb174ff99a5815f953925140a"}, + {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cb92f9061657287eded380d7dc455bbf115430b3aa4741bdc662d02977e7d0af"}, + {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4ace1e220b078c8e48e82c081e35002038657e4b37d403ce940fa679e57113b"}, + {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ef633add81832f4b56d3b4c9408b43d530dfca29e68fb1b797dcb861a2c734cd"}, + {file = "pydantic_core-2.14.6-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7e90d6cc4aad2cc1f5e16ed56e46cebf4877c62403a311af20459c15da76fd91"}, + {file = "pydantic_core-2.14.6-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e8a5ac97ea521d7bde7621d86c30e86b798cdecd985723c4ed737a2aa9e77d0c"}, + {file = "pydantic_core-2.14.6-cp312-none-win32.whl", hash = "sha256:f27207e8ca3e5e021e2402ba942e5b4c629718e665c81b8b306f3c8b1ddbb786"}, + {file = "pydantic_core-2.14.6-cp312-none-win_amd64.whl", hash = "sha256:b3e5fe4538001bb82e2295b8d2a39356a84694c97cb73a566dc36328b9f83b40"}, + {file = "pydantic_core-2.14.6-cp312-none-win_arm64.whl", hash = "sha256:64634ccf9d671c6be242a664a33c4acf12882670b09b3f163cd00a24cffbd74e"}, + {file = "pydantic_core-2.14.6-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:24368e31be2c88bd69340fbfe741b405302993242ccb476c5c3ff48aeee1afe0"}, + {file = "pydantic_core-2.14.6-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:e33b0834f1cf779aa839975f9d8755a7c2420510c0fa1e9fa0497de77cd35d2c"}, + {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6af4b3f52cc65f8a0bc8b1cd9676f8c21ef3e9132f21fed250f6958bd7223bed"}, + {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d15687d7d7f40333bd8266f3814c591c2e2cd263fa2116e314f60d82086e353a"}, + {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:095b707bb287bfd534044166ab767bec70a9bba3175dcdc3371782175c14e43c"}, + {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:94fc0e6621e07d1e91c44e016cc0b189b48db053061cc22d6298a611de8071bb"}, + {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ce830e480f6774608dedfd4a90c42aac4a7af0a711f1b52f807130c2e434c06"}, + {file = 
"pydantic_core-2.14.6-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a306cdd2ad3a7d795d8e617a58c3a2ed0f76c8496fb7621b6cd514eb1532cae8"}, + {file = "pydantic_core-2.14.6-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2f5fa187bde8524b1e37ba894db13aadd64faa884657473b03a019f625cee9a8"}, + {file = "pydantic_core-2.14.6-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:438027a975cc213a47c5d70672e0d29776082155cfae540c4e225716586be75e"}, + {file = "pydantic_core-2.14.6-cp37-none-win32.whl", hash = "sha256:f96ae96a060a8072ceff4cfde89d261837b4294a4f28b84a28765470d502ccc6"}, + {file = "pydantic_core-2.14.6-cp37-none-win_amd64.whl", hash = "sha256:e646c0e282e960345314f42f2cea5e0b5f56938c093541ea6dbf11aec2862391"}, + {file = "pydantic_core-2.14.6-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:db453f2da3f59a348f514cfbfeb042393b68720787bbef2b4c6068ea362c8149"}, + {file = "pydantic_core-2.14.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3860c62057acd95cc84044e758e47b18dcd8871a328ebc8ccdefd18b0d26a21b"}, + {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36026d8f99c58d7044413e1b819a67ca0e0b8ebe0f25e775e6c3d1fabb3c38fb"}, + {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8ed1af8692bd8d2a29d702f1a2e6065416d76897d726e45a1775b1444f5928a7"}, + {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:314ccc4264ce7d854941231cf71b592e30d8d368a71e50197c905874feacc8a8"}, + {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:982487f8931067a32e72d40ab6b47b1628a9c5d344be7f1a4e668fb462d2da42"}, + {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dbe357bc4ddda078f79d2a36fc1dd0494a7f2fad83a0a684465b6f24b46fe80"}, + {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2f6ffc6701a0eb28648c845f4945a194dc7ab3c651f535b81793251e1185ac3d"}, + {file = "pydantic_core-2.14.6-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7f5025db12fc6de7bc1104d826d5aee1d172f9ba6ca936bf6474c2148ac336c1"}, + {file = "pydantic_core-2.14.6-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:dab03ed811ed1c71d700ed08bde8431cf429bbe59e423394f0f4055f1ca0ea60"}, + {file = "pydantic_core-2.14.6-cp38-none-win32.whl", hash = "sha256:dfcbebdb3c4b6f739a91769aea5ed615023f3c88cb70df812849aef634c25fbe"}, + {file = "pydantic_core-2.14.6-cp38-none-win_amd64.whl", hash = "sha256:99b14dbea2fdb563d8b5a57c9badfcd72083f6006caf8e126b491519c7d64ca8"}, + {file = "pydantic_core-2.14.6-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:4ce8299b481bcb68e5c82002b96e411796b844d72b3e92a3fbedfe8e19813eab"}, + {file = "pydantic_core-2.14.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b9a9d92f10772d2a181b5ca339dee066ab7d1c9a34ae2421b2a52556e719756f"}, + {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd9e98b408384989ea4ab60206b8e100d8687da18b5c813c11e92fd8212a98e0"}, + {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4f86f1f318e56f5cbb282fe61eb84767aee743ebe32c7c0834690ebea50c0a6b"}, + {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86ce5fcfc3accf3a07a729779d0b86c5d0309a4764c897d86c11089be61da160"}, + {file = 
"pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3dcf1978be02153c6a31692d4fbcc2a3f1db9da36039ead23173bc256ee3b91b"}, + {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eedf97be7bc3dbc8addcef4142f4b4164066df0c6f36397ae4aaed3eb187d8ab"}, + {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d5f916acf8afbcab6bacbb376ba7dc61f845367901ecd5e328fc4d4aef2fcab0"}, + {file = "pydantic_core-2.14.6-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8a14c192c1d724c3acbfb3f10a958c55a2638391319ce8078cb36c02283959b9"}, + {file = "pydantic_core-2.14.6-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0348b1dc6b76041516e8a854ff95b21c55f5a411c3297d2ca52f5528e49d8411"}, + {file = "pydantic_core-2.14.6-cp39-none-win32.whl", hash = "sha256:de2a0645a923ba57c5527497daf8ec5df69c6eadf869e9cd46e86349146e5975"}, + {file = "pydantic_core-2.14.6-cp39-none-win_amd64.whl", hash = "sha256:aca48506a9c20f68ee61c87f2008f81f8ee99f8d7f0104bff3c47e2d148f89d9"}, + {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:d5c28525c19f5bb1e09511669bb57353d22b94cf8b65f3a8d141c389a55dec95"}, + {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:78d0768ee59baa3de0f4adac9e3748b4b1fffc52143caebddfd5ea2961595277"}, + {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b93785eadaef932e4fe9c6e12ba67beb1b3f1e5495631419c784ab87e975670"}, + {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a874f21f87c485310944b2b2734cd6d318765bcbb7515eead33af9641816506e"}, + {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b89f4477d915ea43b4ceea6756f63f0288941b6443a2b28c69004fe07fde0d0d"}, + {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:172de779e2a153d36ee690dbc49c6db568d7b33b18dc56b69a7514aecbcf380d"}, + {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:dfcebb950aa7e667ec226a442722134539e77c575f6cfaa423f24371bb8d2e94"}, + {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:55a23dcd98c858c0db44fc5c04fc7ed81c4b4d33c653a7c45ddaebf6563a2f66"}, + {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:4241204e4b36ab5ae466ecec5c4c16527a054c69f99bba20f6f75232a6a534e2"}, + {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e574de99d735b3fc8364cba9912c2bec2da78775eba95cbb225ef7dda6acea24"}, + {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1302a54f87b5cd8528e4d6d1bf2133b6aa7c6122ff8e9dc5220fbc1e07bffebd"}, + {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f8e81e4b55930e5ffab4a68db1af431629cf2e4066dbdbfef65348b8ab804ea8"}, + {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:c99462ffc538717b3e60151dfaf91125f637e801f5ab008f81c402f1dff0cd0f"}, + {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e4cf2d5829f6963a5483ec01578ee76d329eb5caf330ecd05b3edd697e7d768a"}, + {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = 
"sha256:cf10b7d58ae4a1f07fccbf4a0a956d705356fea05fb4c70608bb6fa81d103cda"}, + {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:399ac0891c284fa8eb998bcfa323f2234858f5d2efca3950ae58c8f88830f145"}, + {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c6a5c79b28003543db3ba67d1df336f253a87d3112dac3a51b94f7d48e4c0e1"}, + {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:599c87d79cab2a6a2a9df4aefe0455e61e7d2aeede2f8577c1b7c0aec643ee8e"}, + {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43e166ad47ba900f2542a80d83f9fc65fe99eb63ceec4debec160ae729824052"}, + {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3a0b5db001b98e1c649dd55afa928e75aa4087e587b9524a4992316fa23c9fba"}, + {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:747265448cb57a9f37572a488a57d873fd96bf51e5bb7edb52cfb37124516da4"}, + {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:7ebe3416785f65c28f4f9441e916bfc8a54179c8dea73c23023f7086fa601c5d"}, + {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:86c963186ca5e50d5c8287b1d1c9d3f8f024cbe343d048c5bd282aec2d8641f2"}, + {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:e0641b506486f0b4cd1500a2a65740243e8670a2549bb02bc4556a83af84ae03"}, + {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71d72ca5eaaa8d38c8df16b7deb1a2da4f650c41b58bb142f3fb75d5ad4a611f"}, + {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27e524624eace5c59af499cd97dc18bb201dc6a7a2da24bfc66ef151c69a5f2a"}, + {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3dde6cac75e0b0902778978d3b1646ca9f438654395a362cb21d9ad34b24acf"}, + {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:00646784f6cd993b1e1c0e7b0fdcbccc375d539db95555477771c27555e3c556"}, + {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:23598acb8ccaa3d1d875ef3b35cb6376535095e9405d91a3d57a8c7db5d29341"}, + {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7f41533d7e3cf9520065f610b41ac1c76bc2161415955fbcead4981b22c7611e"}, + {file = "pydantic_core-2.14.6.tar.gz", hash = "sha256:1fd0c1d395372843fba13a51c28e3bb9d59bd7aebfeb17358ffaaa1e4dbbe948"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + +[[package]] +name = "pytest" +version = "6.2.5" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, + {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, +] + +[package.dependencies] +atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} +attrs = ">=19.2.0" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +py = ">=1.8.2" +toml = "*" + +[package.extras] +testing = 
["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] + +[[package]] +name = "pytest-asyncio" +version = "0.17.2" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-asyncio-0.17.2.tar.gz", hash = "sha256:6d895b02432c028e6957d25fc936494e78c6305736e785d9fee408b1efbc7ff4"}, + {file = "pytest_asyncio-0.17.2-py3-none-any.whl", hash = "sha256:e0fe5dbea40516b661ef1bcfe0bd9461c2847c4ef4bb40012324f2454fb7d56d"}, +] + +[package.dependencies] +pytest = ">=6.1.0" +typing-extensions = {version = ">=4.0", markers = "python_version < \"3.8\""} + +[package.extras] +testing = ["coverage (==6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (==0.931)"] + +[[package]] +name = "pytest-cov" +version = "3.0.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.6" +files = [ + {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"}, + {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"}, +] + +[package.dependencies] +coverage = {version = ">=5.2.1", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] + +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = 
"PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + +[[package]] +name = "requests" +version = "2.31.0" +description = "Python HTTP for Humans." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, + {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, +] + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<3" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] + +[[package]] +name = "toml" +version = "0.10.2" +description = "Python Library for Tom's Obvious, Minimal Language" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, + {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, +] + +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + +[[package]] +name = "tqdm" +version = "4.66.1" +description = "Fast, Extensible Progress Meter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"}, + {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + +[[package]] +name = "typing-extensions" +version = "4.7.1" +description = "Backported and Experimental Type Hints for Python 3.7+" +optional = false +python-versions = ">=3.7" +files = [ + {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, + {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, +] + +[[package]] +name = "urllib3" +version = "2.0.5" +description = "HTTP library with thread-safe connection pooling, file post, and more." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "urllib3-2.0.5-py3-none-any.whl", hash = "sha256:ef16afa8ba34a1f989db38e1dbbe0c302e4289a47856990d0682e374563ce35e"}, + {file = "urllib3-2.0.5.tar.gz", hash = "sha256:13abf37382ea2ce6fb744d4dad67838eec857c9f4f57009891805e0b5e123594"}, +] + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + +[[package]] +name = "yarl" +version = "1.9.2" +description = "Yet another URL library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82"}, + {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82aa6264b36c50acfb2424ad5ca537a2060ab6de158a5bd2a72a032cc75b9eb8"}, + {file = "yarl-1.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c0c77533b5ed4bcc38e943178ccae29b9bcf48ffd1063f5821192f23a1bd27b9"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee4afac41415d52d53a9833ebae7e32b344be72835bbb589018c9e938045a560"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9bf345c3a4f5ba7f766430f97f9cc1320786f19584acc7086491f45524a551ac"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a96c19c52ff442a808c105901d0bdfd2e28575b3d5f82e2f5fd67e20dc5f4ea"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:891c0e3ec5ec881541f6c5113d8df0315ce5440e244a716b95f2525b7b9f3608"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3a53ba34a636a256d767c086ceb111358876e1fb6b50dfc4d3f4951d40133d5"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:566185e8ebc0898b11f8026447eacd02e46226716229cea8db37496c8cdd26e0"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b0738fb871812722a0ac2154be1f049c6223b9f6f22eec352996b69775b36d4"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:32f1d071b3f362c80f1a7d322bfd7b2d11e33d2adf395cc1dd4df36c9c243095"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e9fdc7ac0d42bc3ea78818557fab03af6181e076a2944f43c38684b4b6bed8e3"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56ff08ab5df8429901ebdc5d15941b59f6253393cb5da07b4170beefcf1b2528"}, + {file = "yarl-1.9.2-cp310-cp310-win32.whl", hash = "sha256:8ea48e0a2f931064469bdabca50c2f578b565fc446f302a79ba6cc0ee7f384d3"}, + {file = "yarl-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:50f33040f3836e912ed16d212f6cc1efb3231a8a60526a407aeb66c1c1956dde"}, + {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:646d663eb2232d7909e6601f1a9107e66f9791f290a1b3dc7057818fe44fc2b6"}, + {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aff634b15beff8902d1f918012fc2a42e0dbae6f469fce134c8a0dc51ca423bb"}, + {file = "yarl-1.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a83503934c6273806aed765035716216cc9ab4e0364f7f066227e1aaea90b8d0"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:b25322201585c69abc7b0e89e72790469f7dad90d26754717f3310bfe30331c2"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22a94666751778629f1ec4280b08eb11815783c63f52092a5953faf73be24191"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ec53a0ea2a80c5cd1ab397925f94bff59222aa3cf9c6da938ce05c9ec20428d"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:159d81f22d7a43e6eabc36d7194cb53f2f15f498dbbfa8edc8a3239350f59fe7"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832b7e711027c114d79dffb92576acd1bd2decc467dec60e1cac96912602d0e6"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:95d2ecefbcf4e744ea952d073c6922e72ee650ffc79028eb1e320e732898d7e8"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d4e2c6d555e77b37288eaf45b8f60f0737c9efa3452c6c44626a5455aeb250b9"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:783185c75c12a017cc345015ea359cc801c3b29a2966c2655cd12b233bf5a2be"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:b8cc1863402472f16c600e3e93d542b7e7542a540f95c30afd472e8e549fc3f7"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:822b30a0f22e588b32d3120f6d41e4ed021806418b4c9f0bc3048b8c8cb3f92a"}, + {file = "yarl-1.9.2-cp311-cp311-win32.whl", hash = "sha256:a60347f234c2212a9f0361955007fcf4033a75bf600a33c88a0a8e91af77c0e8"}, + {file = "yarl-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:be6b3fdec5c62f2a67cb3f8c6dbf56bbf3f61c0f046f84645cd1ca73532ea051"}, + {file = "yarl-1.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38a3928ae37558bc1b559f67410df446d1fbfa87318b124bf5032c31e3447b74"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac9bb4c5ce3975aeac288cfcb5061ce60e0d14d92209e780c93954076c7c4367"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3da8a678ca8b96c8606bbb8bfacd99a12ad5dd288bc6f7979baddd62f71c63ef"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13414591ff516e04fcdee8dc051c13fd3db13b673c7a4cb1350e6b2ad9639ad3"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf74d08542c3a9ea97bb8f343d4fcbd4d8f91bba5ec9d5d7f792dbe727f88938"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e7221580dc1db478464cfeef9b03b95c5852cc22894e418562997df0d074ccc"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:494053246b119b041960ddcd20fd76224149cfea8ed8777b687358727911dd33"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:52a25809fcbecfc63ac9ba0c0fb586f90837f5425edfd1ec9f3372b119585e45"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:e65610c5792870d45d7b68c677681376fcf9cc1c289f23e8e8b39c1485384185"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:1b1bba902cba32cdec51fca038fd53f8beee88b77efc373968d1ed021024cc04"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:662e6016409828ee910f5d9602a2729a8a57d74b163c89a837de3fea050c7582"}, + {file = "yarl-1.9.2-cp37-cp37m-win32.whl", hash = 
"sha256:f364d3480bffd3aa566e886587eaca7c8c04d74f6e8933f3f2c996b7f09bee1b"}, + {file = "yarl-1.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6a5883464143ab3ae9ba68daae8e7c5c95b969462bbe42e2464d60e7e2698368"}, + {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5610f80cf43b6202e2c33ba3ec2ee0a2884f8f423c8f4f62906731d876ef4fac"}, + {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9a4e67ad7b646cd6f0938c7ebfd60e481b7410f574c560e455e938d2da8e0f4"}, + {file = "yarl-1.9.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83fcc480d7549ccebe9415d96d9263e2d4226798c37ebd18c930fce43dfb9574"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fcd436ea16fee7d4207c045b1e340020e58a2597301cfbcfdbe5abd2356c2fb"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84e0b1599334b1e1478db01b756e55937d4614f8654311eb26012091be109d59"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3458a24e4ea3fd8930e934c129b676c27452e4ebda80fbe47b56d8c6c7a63a9e"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:838162460b3a08987546e881a2bfa573960bb559dfa739e7800ceeec92e64417"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4e2d08f07a3d7d3e12549052eb5ad3eab1c349c53ac51c209a0e5991bbada78"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:de119f56f3c5f0e2fb4dee508531a32b069a5f2c6e827b272d1e0ff5ac040333"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:149ddea5abf329752ea5051b61bd6c1d979e13fbf122d3a1f9f0c8be6cb6f63c"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:674ca19cbee4a82c9f54e0d1eee28116e63bc6fd1e96c43031d11cbab8b2afd5"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9b3152f2f5677b997ae6c804b73da05a39daa6a9e85a512e0e6823d81cdad7cc"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415d5a4b080dc9612b1b63cba008db84e908b95848369aa1da3686ae27b6d2b"}, + {file = "yarl-1.9.2-cp38-cp38-win32.whl", hash = "sha256:f7a3d8146575e08c29ed1cd287068e6d02f1c7bdff8970db96683b9591b86ee7"}, + {file = "yarl-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:63c48f6cef34e6319a74c727376e95626f84ea091f92c0250a98e53e62c77c72"}, + {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75df5ef94c3fdc393c6b19d80e6ef1ecc9ae2f4263c09cacb178d871c02a5ba9"}, + {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c027a6e96ef77d401d8d5a5c8d6bc478e8042f1e448272e8d9752cb0aff8b5c8"}, + {file = "yarl-1.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3b078dbe227f79be488ffcfc7a9edb3409d018e0952cf13f15fd6512847f3f7"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59723a029760079b7d991a401386390c4be5bfec1e7dd83e25a6a0881859e716"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b03917871bf859a81ccb180c9a2e6c1e04d2f6a51d953e6a5cdd70c93d4e5a2a"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1012fa63eb6c032f3ce5d2171c267992ae0c00b9e164efe4d73db818465fac3"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74dcbfe780e62f4b5a062714576f16c2f3493a0394e555ab141bf0d746bb955"}, + {file = 
"yarl-1.9.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c56986609b057b4839968ba901944af91b8e92f1725d1a2d77cbac6972b9ed1"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c315df3293cd521033533d242d15eab26583360b58f7ee5d9565f15fee1bef4"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b7232f8dfbd225d57340e441d8caf8652a6acd06b389ea2d3222b8bc89cbfca6"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:53338749febd28935d55b41bf0bcc79d634881195a39f6b2f767870b72514caf"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:066c163aec9d3d073dc9ffe5dd3ad05069bcb03fcaab8d221290ba99f9f69ee3"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8288d7cd28f8119b07dd49b7230d6b4562f9b61ee9a4ab02221060d21136be80"}, + {file = "yarl-1.9.2-cp39-cp39-win32.whl", hash = "sha256:b124e2a6d223b65ba8768d5706d103280914d61f5cae3afbc50fc3dfcc016623"}, + {file = "yarl-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:61016e7d582bc46a5378ffdd02cd0314fb8ba52f40f9cf4d9a5e7dbef88dee18"}, + {file = "yarl-1.9.2.tar.gz", hash = "sha256:04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"}, +] + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" +typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""} + +[[package]] +name = "zipp" +version = "3.15.0" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.7" +files = [ + {file = "zipp-3.15.0-py3-none-any.whl", hash = "sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556"}, + {file = "zipp-3.15.0.tar.gz", hash = "sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] + +[metadata] +lock-version = "2.0" +python-versions = "^3.7" +content-hash = "b7fab8703967f2616ea59a98a437cd30f97f0c8d2a06e399d688814a2a2c64f8" diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml new file mode 100644 index 00000000..2925085b --- /dev/null +++ b/clients/python/pyproject.toml @@ -0,0 +1,29 @@ +[tool.poetry] +name = "text-generation" +version = "0.7.0" +description = "Hugging Face Text Generation Python Client" +license = "Apache-2.0" +authors = ["Olivier Dehaene "] +maintainers = ["Olivier Dehaene "] +readme = "README.md" +homepage = "https://github.com/huggingface/text-generation-inference" +repository = "https://github.com/huggingface/text-generation-inference" + + +[tool.poetry.dependencies] +python = "^3.7" +pydantic = "> 2, < 3" +aiohttp = "^3.8" +huggingface-hub = ">= 0.12, < 1.0" + +[tool.poetry.dev-dependencies] +pytest = "^6.2.5" +pytest-asyncio = "^0.17.2" +pytest-cov = "^3.0.0" + +[tool.pytest.ini_options] +asyncio_mode = "auto" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py new file mode 100644 index 00000000..10a5b599 --- /dev/null +++ b/clients/python/tests/conftest.py @@ -0,0 +1,61 @@ +import pytest + +from text_generation 
import __version__ +from huggingface_hub.utils import build_hf_headers + + +@pytest.fixture +def flan_t5_xxl(): + return "google/flan-t5-xxl" + + +@pytest.fixture +def llama_7b(): + return "" # "meta-llama/Llama-2-7b-chat-hf" + + +@pytest.fixture +def fake_model(): + return "fake/model" + + +@pytest.fixture +def unsupported_model(): + return "gpt2" + + +@pytest.fixture +def base_url(): + return "http://127.0.0.1:3000" # "https://api-inference.huggingface.co/models" + + +@pytest.fixture +def bloom_url(base_url, bloom_model): + return f"{base_url}/{bloom_model}" + + +@pytest.fixture +def flan_t5_xxl_url(base_url, flan_t5_xxl): + return f"{base_url}/{flan_t5_xxl}" + + +@pytest.fixture +def llama_7b_url(base_url, llama_7b): + return f"{base_url}/{llama_7b}" + + +@pytest.fixture +def fake_url(base_url, fake_model): + return f"{base_url}/{fake_model}" + + +@pytest.fixture +def unsupported_url(base_url, unsupported_model): + return f"{base_url}/{unsupported_model}" + + +@pytest.fixture(scope="session") +def hf_headers(): + return build_hf_headers( + library_name="text-generation-tests", library_version=__version__ + ) diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py new file mode 100644 index 00000000..2cfd019d --- /dev/null +++ b/clients/python/tests/test_client.py @@ -0,0 +1,167 @@ +import pytest + +from text_generation import Client, AsyncClient +from text_generation.errors import NotFoundError, ValidationError +from text_generation.types import FinishReason, InputToken + + +def test_generate_lora(llama_7b_url, hf_headers): + client = Client(llama_7b_url, hf_headers) + response = client.generate( + "test", + max_new_tokens=1, + decoder_input_details=True, + lora_id="abcdabcd987/gsm8k-llama2-7b-lora-16", + ) + assert response.generated_text == "_" + + response = client.download_lora_adapter("abcdabcd987/gsm8k-llama2-7b-lora-16") + assert response.status_code == 200 + + +def test_generate(llama_7b_url, hf_headers): + client = Client(llama_7b_url, hf_headers) + response = client.generate("test", max_new_tokens=1, decoder_input_details=True) + + assert response.generated_text == "_" + assert response.details.finish_reason == FinishReason.Length + assert response.details.generated_tokens == 1 + assert response.details.seed is None + assert len(response.details.prefill) == 2 + assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None) + assert len(response.details.tokens) == 1 + assert response.details.tokens[0].id == 29918 + assert response.details.tokens[0].text == "_" + assert not response.details.tokens[0].special + + +def test_generate_best_of(llama_7b_url, hf_headers): + client = Client(llama_7b_url, hf_headers) + response = client.generate( + "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True + ) + + assert response.details.seed is not None + assert response.details.best_of_sequences is not None + assert len(response.details.best_of_sequences) == 1 + assert response.details.best_of_sequences[0].seed is not None + + +def test_generate_not_found(fake_url, hf_headers): + client = Client(fake_url, hf_headers) + with pytest.raises(NotFoundError): + client.generate("test") + + +def test_generate_validation_error(llama_7b_url, hf_headers): + client = Client(llama_7b_url, hf_headers) + with pytest.raises(ValidationError): + client.generate("test", max_new_tokens=10_000) + + +def test_generate_stream(llama_7b_url, hf_headers): + client = Client(llama_7b_url, hf_headers) + responses = [ + response for response in 
client.generate_stream("test", max_new_tokens=1) + ] + + assert len(responses) == 1 + response = responses[0] + + assert response.generated_text == "_" + assert response.details.finish_reason == FinishReason.Length + assert response.details.generated_tokens == 1 + assert response.details.seed is None + + +def test_generate_stream_not_found(fake_url, hf_headers): + client = Client(fake_url, hf_headers) + with pytest.raises(NotFoundError): + list(client.generate_stream("test")) + + +def test_generate_stream_validation_error(llama_7b_url, hf_headers): + client = Client(llama_7b_url, hf_headers) + with pytest.raises(ValidationError): + list(client.generate_stream("test", max_new_tokens=10_000)) + + +@pytest.mark.asyncio +async def test_generate_async(llama_7b_url, hf_headers): + client = AsyncClient(llama_7b_url, hf_headers) + response = await client.generate( + "test", max_new_tokens=1, decoder_input_details=True + ) + + assert response.generated_text == "_" + assert response.details.finish_reason == FinishReason.Length + assert response.details.generated_tokens == 1 + assert response.details.seed is None + assert len(response.details.prefill) == 2 + assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None) + assert response.details.prefill[1] == InputToken( + id=1243, text="test", logprob=-10.96875 + ) + assert len(response.details.tokens) == 1 + assert response.details.tokens[0].id == 29918 + assert response.details.tokens[0].text == "_" + assert not response.details.tokens[0].special + + +@pytest.mark.asyncio +async def test_generate_async_best_of(llama_7b_url, hf_headers): + client = AsyncClient(llama_7b_url, hf_headers) + response = await client.generate( + "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True + ) + + assert response.details.seed is not None + assert response.details.best_of_sequences is not None + assert len(response.details.best_of_sequences) == 1 + assert response.details.best_of_sequences[0].seed is not None + + +@pytest.mark.asyncio +async def test_generate_async_not_found(fake_url, hf_headers): + client = AsyncClient(fake_url, hf_headers) + with pytest.raises(NotFoundError): + await client.generate("test") + + +@pytest.mark.asyncio +async def test_generate_async_validation_error(llama_7b_url, hf_headers): + client = AsyncClient(llama_7b_url, hf_headers) + with pytest.raises(ValidationError): + await client.generate("test", max_new_tokens=10_000) + + +@pytest.mark.asyncio +async def test_generate_stream_async(llama_7b_url, hf_headers): + client = AsyncClient(llama_7b_url, hf_headers) + responses = [ + response async for response in client.generate_stream("test", max_new_tokens=1) + ] + + assert len(responses) == 1 + response = responses[0] + + assert response.generated_text == "_" + assert response.details.finish_reason == FinishReason.Length + assert response.details.generated_tokens == 1 + assert response.details.seed is None + + +@pytest.mark.asyncio +async def test_generate_stream_async_not_found(fake_url, hf_headers): + client = AsyncClient(fake_url, hf_headers) + with pytest.raises(NotFoundError): + async for _ in client.generate_stream("test"): + pass + + +@pytest.mark.asyncio +async def test_generate_stream_async_validation_error(llama_7b_url, hf_headers): + client = AsyncClient(llama_7b_url, hf_headers) + with pytest.raises(ValidationError): + async for _ in client.generate_stream("test", max_new_tokens=10_000): + pass diff --git a/clients/python/tests/test_errors.py b/clients/python/tests/test_errors.py new 
file mode 100644 index 00000000..8389ed31 --- /dev/null +++ b/clients/python/tests/test_errors.py @@ -0,0 +1,64 @@ +from text_generation.errors import ( + parse_error, + GenerationError, + IncompleteGenerationError, + OverloadedError, + ValidationError, + BadRequestError, + ShardNotReadyError, + ShardTimeoutError, + NotFoundError, + RateLimitExceededError, + UnknownError, +) + + +def test_generation_error(): + payload = {"error_type": "generation", "error": "test"} + assert isinstance(parse_error(400, payload), GenerationError) + + +def test_incomplete_generation_error(): + payload = {"error_type": "incomplete_generation", "error": "test"} + assert isinstance(parse_error(400, payload), IncompleteGenerationError) + + +def test_overloaded_error(): + payload = {"error_type": "overloaded", "error": "test"} + assert isinstance(parse_error(400, payload), OverloadedError) + + +def test_validation_error(): + payload = {"error_type": "validation", "error": "test"} + assert isinstance(parse_error(400, payload), ValidationError) + + +def test_bad_request_error(): + payload = {"error": "test"} + assert isinstance(parse_error(400, payload), BadRequestError) + + +def test_shard_not_ready_error(): + payload = {"error": "test"} + assert isinstance(parse_error(403, payload), ShardNotReadyError) + assert isinstance(parse_error(424, payload), ShardNotReadyError) + + +def test_shard_timeout_error(): + payload = {"error": "test"} + assert isinstance(parse_error(504, payload), ShardTimeoutError) + + +def test_not_found_error(): + payload = {"error": "test"} + assert isinstance(parse_error(404, payload), NotFoundError) + + +def test_rate_limit_exceeded_error(): + payload = {"error": "test"} + assert isinstance(parse_error(429, payload), RateLimitExceededError) + + +def test_unknown_error(): + payload = {"error": "test"} + assert isinstance(parse_error(500, payload), UnknownError) diff --git a/clients/python/tests/test_inference_api.py b/clients/python/tests/test_inference_api.py new file mode 100644 index 00000000..59297c26 --- /dev/null +++ b/clients/python/tests/test_inference_api.py @@ -0,0 +1,42 @@ +import pytest + +from text_generation import ( + InferenceAPIClient, + InferenceAPIAsyncClient, + Client, + AsyncClient, +) +from text_generation.errors import NotSupportedError, NotFoundError +from text_generation.inference_api import check_model_support, deployed_models + + +def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model): + assert check_model_support(flan_t5_xxl) + assert not check_model_support(unsupported_model) + + with pytest.raises(NotFoundError): + check_model_support(fake_model) + + +def test_deployed_models(): + deployed_models() + + +def test_client(flan_t5_xxl): + client = InferenceAPIClient(flan_t5_xxl) + assert isinstance(client, Client) + + +def test_client_unsupported_model(unsupported_model): + with pytest.raises(NotSupportedError): + InferenceAPIClient(unsupported_model) + + +def test_async_client(flan_t5_xxl): + client = InferenceAPIAsyncClient(flan_t5_xxl) + assert isinstance(client, AsyncClient) + + +def test_async_client_unsupported_model(unsupported_model): + with pytest.raises(NotSupportedError): + InferenceAPIAsyncClient(unsupported_model) diff --git a/clients/python/tests/test_types.py b/clients/python/tests/test_types.py new file mode 100644 index 00000000..77689ade --- /dev/null +++ b/clients/python/tests/test_types.py @@ -0,0 +1,84 @@ +import pytest + +from text_generation.types import Parameters, Request +from text_generation.errors import ValidationError + 
+ +def test_parameters_validation(): + # Test best_of + Parameters(best_of=1) + with pytest.raises(ValidationError): + Parameters(best_of=0) + with pytest.raises(ValidationError): + Parameters(best_of=-1) + Parameters(best_of=2, do_sample=True) + with pytest.raises(ValidationError): + Parameters(best_of=2) + with pytest.raises(ValidationError): + Parameters(best_of=2, seed=1) + + # Test repetition_penalty + Parameters(repetition_penalty=1) + with pytest.raises(ValidationError): + Parameters(repetition_penalty=0) + with pytest.raises(ValidationError): + Parameters(repetition_penalty=-1) + + # Test seed + Parameters(seed=1) + with pytest.raises(ValidationError): + Parameters(seed=-1) + + # Test temperature + Parameters(temperature=1) + with pytest.raises(ValidationError): + Parameters(temperature=0) + with pytest.raises(ValidationError): + Parameters(temperature=-1) + + # Test top_k + Parameters(top_k=1) + with pytest.raises(ValidationError): + Parameters(top_k=0) + with pytest.raises(ValidationError): + Parameters(top_k=-1) + + # Test top_p + Parameters(top_p=0.5) + with pytest.raises(ValidationError): + Parameters(top_p=0) + with pytest.raises(ValidationError): + Parameters(top_p=-1) + with pytest.raises(ValidationError): + Parameters(top_p=1) + + # Test truncate + Parameters(truncate=1) + with pytest.raises(ValidationError): + Parameters(truncate=0) + with pytest.raises(ValidationError): + Parameters(truncate=-1) + + # Test typical_p + Parameters(typical_p=0.5) + with pytest.raises(ValidationError): + Parameters(typical_p=0) + with pytest.raises(ValidationError): + Parameters(typical_p=-1) + with pytest.raises(ValidationError): + Parameters(typical_p=1) + + +def test_request_validation(): + Request(inputs="test") + + with pytest.raises(ValidationError): + Request(inputs="") + + Request(inputs="test", stream=True) + Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True)) + + with pytest.raises(ValidationError): + Request( + inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True + ) diff --git a/clients/python/text_generation/__init__.py b/clients/python/text_generation/__init__.py new file mode 100644 index 00000000..a8e67071 --- /dev/null +++ b/clients/python/text_generation/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "0.6.0" + +DEPRECATION_WARNING = ( + "`text_generation` clients are deprecated and will be removed in the near future. " + "Please use the `InferenceClient` from the `huggingface_hub` package instead." 
+) + +from text_generation.client import Client, AsyncClient +from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py new file mode 100644 index 00000000..dbcb9cef --- /dev/null +++ b/clients/python/text_generation/client.py @@ -0,0 +1,1043 @@ +import json +import requests +import warnings + +from aiohttp import ClientSession, ClientTimeout +from pydantic import ValidationError +from typing import Dict, Optional, List, AsyncIterator, Iterator, Union + +from text_generation import DEPRECATION_WARNING +from text_generation.types import ( + StreamResponse, + Response, + Request, + Parameters, + Grammar, + CompletionRequest, + Completion, + CompletionComplete, + ChatRequest, + ChatCompletionChunk, + ChatComplete, + Message, + Tool, +) +from text_generation.errors import parse_error + +# emit deprecation warnings +warnings.simplefilter("always", DeprecationWarning) + + +class Client: + """Client to make calls to a text-generation-inference instance + + Example: + + ```python + >>> from text_generation import Client + + >>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz") + >>> client.generate("Why is the sky blue?").generated_text + ' Rayleigh scattering' + + >>> result = "" + >>> for response in client.generate_stream("Why is the sky blue?"): + >>> if not response.token.special: + >>> result += response.token.text + >>> result + ' Rayleigh scattering' + ``` + """ + + def __init__( + self, + base_url: str, + headers: Optional[Dict[str, str]] = None, + cookies: Optional[Dict[str, str]] = None, + timeout: int = 10, + ): + """ + Args: + base_url (`str`): + text-generation-inference instance base url + headers (`Optional[Dict[str, str]]`): + Additional headers + cookies (`Optional[Dict[str, str]]`): + Cookies to include in the requests + timeout (`int`): + Timeout in seconds + """ + warnings.warn(DEPRECATION_WARNING, DeprecationWarning) + self.base_url = base_url + self.headers = headers + self.cookies = cookies + self.timeout = timeout + + def completion( + self, + prompt: str, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + seed: Optional[int] = None, + stream: bool = False, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + stop: Optional[List[str]] = None, + lora_id: Optional[str] = None, + ): + """ + Given a prompt, generate a response synchronously + + Args: + prompt (`str`): + Prompt + frequency_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty + Penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + max_tokens (`int`): + Maximum number of generated tokens + repetition_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + seed (`int`): + Random sampling seed + stream (`bool`): + Stream the response + temperature (`float`): + The value used to module the logits distribution. 
+ top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation + stop (`List[str]`): + Stop generating tokens if a member of `stop` is generated + """ + request = CompletionRequest( + model="tgi", + prompt=prompt, + frequency_penalty=frequency_penalty, + max_tokens=max_tokens, + repetition_penalty=repetition_penalty, + seed=seed, + stream=stream, + temperature=temperature, + top_p=top_p, + stop=stop, + lora_id=lora_id, + ) + if not stream: + resp = requests.post( + f"{self.base_url}/v1/completions", + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, payload) + return Completion(**payload) + else: + return self._completion_stream_response(request) + + def _completion_stream_response(self, request): + resp = requests.post( + f"{self.base_url}/v1/completions", + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + stream=True, + ) + # iterate and print stream + for byte_payload in resp.iter_lines(): + if byte_payload == b"\n": + continue + payload = byte_payload.decode("utf-8") + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + try: + response = CompletionComplete(**json_payload) + yield response + except ValidationError: + raise parse_error(resp.status, json_payload) + + def chat( + self, + messages: List[Message], + repetition_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + stream: bool = False, + seed: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Tool]] = None, + tool_prompt: Optional[str] = None, + tool_choice: Optional[str] = None, + stop: Optional[List[str]] = None, + lora_id: Optional[str] = None, + ): + """ + Given a list of messages, generate a response asynchronously + + Args: + messages (`List[Message]`): + List of messages + repetition_penalty (`float`): + The parameter for repetition penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + frequency_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty + Penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + logit_bias (`List[float]`): + Adjust the likelihood of specified tokens + logprobs (`bool`): + Include log probabilities in the response + top_logprobs (`int`): + Include the `n` most likely tokens at each step + max_tokens (`int`): + Maximum number of generated tokens + n (`int`): + Generate `n` completions + presence_penalty (`float`): + The parameter for presence penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + stream (`bool`): + Stream the response + seed (`int`): + Random sampling seed + temperature (`float`): + The value used to module the logits distribution. 
+ top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation + tools (`List[Tool]`): + List of tools to use + tool_prompt (`str`): + A prompt to be appended before the tools + tool_choice (`str`): + The tool to use + stop (`List[str]`): + Stop generating tokens if a member of `stop` is generated + + """ + request = ChatRequest( + model="tgi", + messages=messages, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + top_logprobs=top_logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + stream=stream, + seed=seed, + temperature=temperature, + top_p=top_p, + tools=tools, + tool_prompt=tool_prompt, + tool_choice=tool_choice, + stop=stop, + lora_id=lora_id, + ) + if not stream: + resp = requests.post( + f"{self.base_url}/v1/chat/completions", + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, payload) + return ChatComplete(**payload) + else: + return self._chat_stream_response(request) + + def _chat_stream_response(self, request): + resp = requests.post( + f"{self.base_url}/v1/chat/completions", + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + stream=True, + ) + # iterate and print stream + for byte_payload in resp.iter_lines(): + if byte_payload == b"\n": + continue + payload = byte_payload.decode("utf-8") + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + try: + response = ChatCompletionChunk(**json_payload) + yield response + except ValidationError: + raise parse_error(resp.status, json_payload) + + def generate( + self, + prompt: str, + do_sample: bool = False, + max_new_tokens: int = 20, + best_of: Optional[int] = None, + repetition_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: bool = False, + decoder_input_details: bool = False, + top_n_tokens: Optional[int] = None, + grammar: Optional[Grammar] = None, + lora_id: Optional[str] = None, + ) -> Response: + """ + Given a prompt, generate the following text + + Args: + prompt (`str`): + Input text + do_sample (`bool`): + Activate logits sampling + max_new_tokens (`int`): + Maximum number of generated tokens + best_of (`int`): + Generate best_of sequences and return the one if the highest token logprobs + repetition_penalty (`float`): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + frequency_penalty (`float`): + The parameter for frequency penalty. 1.0 means no penalty + Penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. 
+ return_full_text (`bool`): + Whether to prepend the prompt to the generated text + seed (`int`): + Random sampling seed + stop_sequences (`List[str]`): + Stop generating tokens if a member of `stop_sequences` is generated + temperature (`float`): + The value used to module the logits distribution. + top_k (`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + truncate (`int`): + Truncate inputs tokens to the given size + typical_p (`float`): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + watermark (`bool`): + Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + decoder_input_details (`bool`): + Return the decoder input token logprobs and ids + top_n_tokens (`int`): + Return the `n` most likely tokens at each step + grammar (`Grammar`): + Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation + of the text to match a regular expression or JSON schema. + + Returns: + Response: generated response + """ + # Validate parameters + parameters = Parameters( + best_of=best_of, + details=True, + do_sample=do_sample, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + return_full_text=return_full_text, + seed=seed, + stop=stop_sequences if stop_sequences is not None else [], + temperature=temperature, + top_k=top_k, + top_p=top_p, + truncate=truncate, + typical_p=typical_p, + watermark=watermark, + decoder_input_details=decoder_input_details, + top_n_tokens=top_n_tokens, + grammar=grammar, + ) + request = Request( + inputs=prompt, stream=False, parameters=parameters, lora_id=lora_id + ) + + resp = requests.post( + self.base_url, + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + if resp.status_code == 404: + raise parse_error( + resp.status_code, + {"error": "Service not found.", "errory_type": "generation"}, + ) + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, payload) + return Response(**payload[0]) + + def generate_stream( + self, + prompt: str, + do_sample: bool = False, + max_new_tokens: int = 20, + repetition_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: bool = False, + top_n_tokens: Optional[int] = None, + grammar: Optional[Grammar] = None, + lora_id: Optional[str] = None, + ) -> Iterator[StreamResponse]: + """ + Given a prompt, generate the following stream of tokens + + Args: + prompt (`str`): + Input text + do_sample (`bool`): + Activate logits sampling + max_new_tokens (`int`): + Maximum number of generated tokens + repetition_penalty (`float`): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + frequency_penalty (`float`): + The parameter for frequency penalty. 
1.0 means no penalty + Penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + return_full_text (`bool`): + Whether to prepend the prompt to the generated text + seed (`int`): + Random sampling seed + stop_sequences (`List[str]`): + Stop generating tokens if a member of `stop_sequences` is generated + temperature (`float`): + The value used to module the logits distribution. + top_k (`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + truncate (`int`): + Truncate inputs tokens to the given size + typical_p (`float`): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + watermark (`bool`): + Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + top_n_tokens (`int`): + Return the `n` most likely tokens at each step + grammar (`Grammar`): + Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation + of the text to match a regular expression or JSON schema. + + Returns: + Iterator[StreamResponse]: stream of generated tokens + """ + # Validate parameters + parameters = Parameters( + best_of=None, + details=True, + decoder_input_details=False, + do_sample=do_sample, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + return_full_text=return_full_text, + seed=seed, + stop=stop_sequences if stop_sequences is not None else [], + temperature=temperature, + top_k=top_k, + top_p=top_p, + truncate=truncate, + typical_p=typical_p, + watermark=watermark, + top_n_tokens=top_n_tokens, + grammar=grammar, + ) + request = Request( + inputs=prompt, stream=True, parameters=parameters, lora_id=lora_id + ) + + resp = requests.post( + self.base_url, + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + stream=True, + ) + + if resp.status_code == 404: + raise parse_error( + resp.status_code, + {"error": "Service not found.", "errory_type": "generation"}, + ) + if resp.status_code != 200: + raise parse_error(resp.status_code, resp.json()) + + # Parse ServerSentEvents + for byte_payload in resp.iter_lines(): + # Skip line + if byte_payload == b"\n": + continue + + payload = byte_payload.decode("utf-8") + + # Event data + if payload.startswith("data:"): + # Decode payload + json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) + # Parse payload + try: + response = StreamResponse(**json_payload) + except ValidationError: + # If we failed to parse the payload, then it is an error payload + raise parse_error(resp.status_code, json_payload) + yield response + + def download_lora_adapter(self, lora_id: str, hf_api_token: Optional[str] = None): + req = {} + req["lora_id"] = lora_id + req["hf_api_token"] = hf_api_token + resp = requests.post( + f"{self.base_url}/download_lora_adapter", + json=req, + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + stream=True, + ) + + if resp.status_code != 200: + raise parse_error(resp.status_code, resp.json()) + return resp + + +class AsyncClient: + """Asynchronous Client to make calls to a text-generation-inference instance + + Example: + + ```python + >>> from text_generation import AsyncClient 
+ + >>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz") + >>> response = await client.generate("Why is the sky blue?") + >>> response.generated_text + ' Rayleigh scattering' + + >>> result = "" + >>> async for response in client.generate_stream("Why is the sky blue?"): + >>> if not response.token.special: + >>> result += response.token.text + >>> result + ' Rayleigh scattering' + ``` + """ + + def __init__( + self, + base_url: str, + headers: Optional[Dict[str, str]] = None, + cookies: Optional[Dict[str, str]] = None, + timeout: int = 10, + ): + """ + Args: + base_url (`str`): + text-generation-inference instance base url + headers (`Optional[Dict[str, str]]`): + Additional headers + cookies (`Optional[Dict[str, str]]`): + Cookies to include in the requests + timeout (`int`): + Timeout in seconds + """ + warnings.warn(DEPRECATION_WARNING, DeprecationWarning) + self.base_url = base_url + self.headers = headers + self.cookies = cookies + self.timeout = ClientTimeout(timeout) + + async def completion( + self, + prompt: str, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + seed: Optional[int] = None, + stream: bool = False, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + stop: Optional[List[str]] = None, + lora_id: Optional[str] = None, + ) -> Union[Completion, AsyncIterator[CompletionComplete]]: + """ + Given a prompt, generate a response asynchronously + + Args: + prompt (`str`): + Prompt + frequency_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty + Penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + max_tokens (`int`): + Maximum number of generated tokens + repetition_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + seed (`int`): + Random sampling seed + stream (`bool`): + Stream the response + temperature (`float`): + The value used to module the logits distribution. 
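A short sketch of driving this coroutine from `asyncio`: awaiting the call yields either a `Completion` (non-streaming) or an async iterator of `CompletionComplete` chunks (streaming). The server URL and prompt are placeholders, and a TGI-compatible endpoint is assumed.

```python
# Minimal sketch of the asynchronous completion API defined above.
# Assumptions (not part of this diff): the host and prompt are placeholders.
import asyncio

from text_generation import AsyncClient


async def main() -> None:
    client = AsyncClient("http://127.0.0.1:8080")

    # Non-streaming: awaiting `completion` returns a Completion whose
    # choices carry the generated text.
    completion = await client.completion("What is Deep Learning?", max_tokens=32)
    print(completion.choices[0].text)

    # Streaming: with stream=True the awaited call returns an async
    # iterator of CompletionComplete chunks.
    stream = await client.completion(
        "What is Deep Learning?", max_tokens=32, stream=True
    )
    text = ""
    async for chunk in stream:
        text += chunk.text
    print(text)


asyncio.run(main())
```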
+ top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation + stop (`List[str]`): + Stop generating tokens if a member of `stop` is generated + """ + request = CompletionRequest( + model="tgi", + prompt=prompt, + frequency_penalty=frequency_penalty, + max_tokens=max_tokens, + repetition_penalty=repetition_penalty, + seed=seed, + stream=stream, + temperature=temperature, + top_p=top_p, + stop=stop, + lora_id=lora_id, + ) + if not stream: + return await self._completion_single_response(request) + else: + return self._completion_stream_response(request) + + async def _completion_single_response(self, request): + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post( + f"{self.base_url}/v1/completions", json=request.dict() + ) as resp: + payload = await resp.json() + if resp.status != 200: + raise parse_error(resp.status, payload) + return Completion(**payload) + + async def _completion_stream_response(self, request): + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post( + f"{self.base_url}/v1/completions", json=request.dict() + ) as resp: + async for byte_payload in resp.content: + if byte_payload == b"\n": + continue + payload = byte_payload.decode("utf-8") + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + try: + response = CompletionComplete(**json_payload) + yield response + except ValidationError: + raise parse_error(resp.status, json_payload) + + async def chat( + self, + messages: List[Message], + repetition_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + stream: bool = False, + seed: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Tool]] = None, + tool_prompt: Optional[str] = None, + tool_choice: Optional[str] = None, + stop: Optional[List[str]] = None, + lora_id: Optional[str] = None, + ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]: + """ + Given a list of messages, generate a response asynchronously + + Args: + messages (`List[Message]`): + List of messages + repetition_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + frequency_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty + Penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + logit_bias (`List[float]`): + Adjust the likelihood of specified tokens + logprobs (`bool`): + Include log probabilities in the response + top_logprobs (`int`): + Include the `n` most likely tokens at each step + max_tokens (`int`): + Maximum number of generated tokens + n (`int`): + Generate `n` completions + presence_penalty (`float`): + The parameter for presence penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. 
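The streaming helpers above all follow the same server-sent-events convention: each event arrives as a `data: {json}` line, the prefix is stripped, and the JSON body is validated into a pydantic model. Below is a self-contained sketch of that single step, using a made-up payload shaped like the `CompletionComplete` chunks yielded by `_completion_stream_response`.

```python
# Minimal sketch of the SSE-line parsing used by the streaming helpers.
# The payload below is fabricated for illustration; real events come from
# the /v1/completions endpoint of a running server.
import json

from text_generation.types import CompletionComplete

byte_payload = b'data: {"index": 0, "text": " Deep Learning is", "logprobs": null, "finish_reason": ""}\n'

payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
    # Strip the "data:" prefix and trailing newline, then validate.
    json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
    chunk = CompletionComplete(**json_payload)
    print(chunk.text)  # -> " Deep Learning is"
```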
+ stream (`bool`): + Stream the response + seed (`int`): + Random sampling seed + temperature (`float`): + The value used to module the logits distribution. + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation + tools (`List[Tool]`): + List of tools to use + tool_prompt (`str`): + A prompt to be appended before the tools + tool_choice (`str`): + The tool to use + stop (`List[str]`): + Stop generating tokens if a member of `stop` is generated + + """ + request = ChatRequest( + model="tgi", + messages=messages, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + top_logprobs=top_logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + stream=stream, + seed=seed, + temperature=temperature, + top_p=top_p, + tools=tools, + tool_prompt=tool_prompt, + tool_choice=tool_choice, + stop=stop, + lora_id=lora_id, + ) + if not stream: + return await self._chat_single_response(request) + else: + return self._chat_stream_response(request) + + async def _chat_single_response(self, request): + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post( + f"{self.base_url}/v1/chat/completions", json=request.dict() + ) as resp: + payload = await resp.json() + if resp.status != 200: + raise parse_error(resp.status, payload) + return ChatComplete(**payload) + + async def _chat_stream_response(self, request): + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post( + f"{self.base_url}/v1/chat/completions", json=request.dict() + ) as resp: + async for byte_payload in resp.content: + if byte_payload == b"\n": + continue + payload = byte_payload.decode("utf-8") + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + try: + response = ChatCompletionChunk(**json_payload) + yield response + except ValidationError: + raise parse_error(resp.status, json_payload) + + async def generate( + self, + prompt: str, + do_sample: bool = False, + max_new_tokens: int = 20, + best_of: Optional[int] = None, + repetition_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: bool = False, + decoder_input_details: bool = False, + top_n_tokens: Optional[int] = None, + grammar: Optional[Grammar] = None, + lora_id: Optional[str] = None, + ) -> Response: + """ + Given a prompt, generate the following text asynchronously + + Args: + prompt (`str`): + Input text + do_sample (`bool`): + Activate logits sampling + max_new_tokens (`int`): + Maximum number of generated tokens + best_of (`int`): + Generate best_of sequences and return the one if the highest token logprobs + repetition_penalty (`float`): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + frequency_penalty (`float`): + The parameter for frequency penalty. 
1.0 means no penalty + Penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + return_full_text (`bool`): + Whether to prepend the prompt to the generated text + seed (`int`): + Random sampling seed + stop_sequences (`List[str]`): + Stop generating tokens if a member of `stop_sequences` is generated + temperature (`float`): + The value used to module the logits distribution. + top_k (`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + truncate (`int`): + Truncate inputs tokens to the given size + typical_p (`float`): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + watermark (`bool`): + Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + decoder_input_details (`bool`): + Return the decoder input token logprobs and ids + top_n_tokens (`int`): + Return the `n` most likely tokens at each step + grammar (`Grammar`): + Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation + of the text to match a regular expression or JSON schema. + + Returns: + Response: generated response + """ + + # Validate parameters + parameters = Parameters( + best_of=best_of, + details=True, + decoder_input_details=decoder_input_details, + do_sample=do_sample, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + return_full_text=return_full_text, + seed=seed, + stop=stop_sequences if stop_sequences is not None else [], + temperature=temperature, + top_k=top_k, + top_p=top_p, + truncate=truncate, + typical_p=typical_p, + watermark=watermark, + top_n_tokens=top_n_tokens, + grammar=grammar, + ) + request = Request( + inputs=prompt, stream=False, parameters=parameters, lora_id=lora_id + ) + + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post(self.base_url, json=request.dict()) as resp: + if resp.status == 404: + raise parse_error( + resp.status, + {"error": "Service not found.", "errory_type": "generation"}, + ) + + payload = await resp.json() + + if resp.status != 200: + raise parse_error(resp.status, payload) + return Response(**payload[0]) + + async def generate_stream( + self, + prompt: str, + do_sample: bool = False, + max_new_tokens: int = 20, + repetition_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: bool = False, + top_n_tokens: Optional[int] = None, + grammar: Optional[Grammar] = None, + lora_id: Optional[str] = None, + ) -> AsyncIterator[StreamResponse]: + """ + Given a prompt, generate the following stream of tokens asynchronously + + Args: + prompt (`str`): + Input text + do_sample (`bool`): + Activate logits sampling + max_new_tokens (`int`): + Maximum number of generated tokens + repetition_penalty (`float`): + The parameter for repetition penalty. 1.0 means no penalty. 
See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + frequency_penalty (`float`): + The parameter for frequency penalty. 1.0 means no penalty + Penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + return_full_text (`bool`): + Whether to prepend the prompt to the generated text + seed (`int`): + Random sampling seed + stop_sequences (`List[str]`): + Stop generating tokens if a member of `stop_sequences` is generated + temperature (`float`): + The value used to module the logits distribution. + top_k (`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + truncate (`int`): + Truncate inputs tokens to the given size + typical_p (`float`): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + watermark (`bool`): + Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + top_n_tokens (`int`): + Return the `n` most likely tokens at each step + grammar (`Grammar`): + Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation + of the text to match a regular expression or JSON schema. + + Returns: + AsyncIterator[StreamResponse]: stream of generated tokens + """ + # Validate parameters + parameters = Parameters( + best_of=None, + details=True, + decoder_input_details=False, + do_sample=do_sample, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + return_full_text=return_full_text, + seed=seed, + stop=stop_sequences if stop_sequences is not None else [], + temperature=temperature, + top_k=top_k, + top_p=top_p, + truncate=truncate, + typical_p=typical_p, + watermark=watermark, + top_n_tokens=top_n_tokens, + grammar=grammar, + ) + request = Request( + inputs=prompt, stream=True, parameters=parameters, lora_id=lora_id + ) + + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post(self.base_url, json=request.dict()) as resp: + if resp.status == 404: + raise parse_error( + resp.status, + {"error": "Service not found.", "errory_type": "generation"}, + ) + if resp.status != 200: + raise parse_error(resp.status, await resp.json()) + + # Parse ServerSentEvents + async for byte_payload in resp.content: + # Skip line + if byte_payload == b"\n": + continue + + payload = byte_payload.decode("utf-8") + + # Event data + if payload.startswith("data:"): + # Decode payload + json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) + # Parse payload + try: + response = StreamResponse(**json_payload) + except ValidationError: + # If we failed to parse the payload, then it is an error payload + raise parse_error(resp.status, json_payload) + yield response diff --git a/clients/python/text_generation/errors.py b/clients/python/text_generation/errors.py new file mode 100644 index 00000000..dbf0b761 --- /dev/null +++ b/clients/python/text_generation/errors.py @@ -0,0 +1,106 @@ +from typing import Dict + + +# Text Generation Inference Errors +class ValidationError(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class GenerationError(Exception): + def __init__(self, 
message: str): + super().__init__(message) + + +class OverloadedError(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class IncompleteGenerationError(Exception): + def __init__(self, message: str): + super().__init__(message) + + +# API Inference Errors +class BadRequestError(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class ShardNotReadyError(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class ShardTimeoutError(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class NotFoundError(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class RateLimitExceededError(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class NotSupportedError(Exception): + def __init__(self, model_id: str): + message = ( + f"Model `{model_id}` is not available for inference with this client. \n" + "Use `huggingface_hub.inference_api.InferenceApi` instead." + ) + super(NotSupportedError, self).__init__(message) + + +# Unknown error +class UnknownError(Exception): + def __init__(self, message: str): + super().__init__(message) + + +def parse_error(status_code: int, payload: Dict[str, str]) -> Exception: + """ + Parse error given an HTTP status code and a json payload + + Args: + status_code (`int`): + HTTP status code + payload (`Dict[str, str]`): + Json payload + + Returns: + Exception: parsed exception + + """ + # Try to parse a Text Generation Inference error + message = payload["error"] + if "error_type" in payload: + error_type = payload["error_type"] + if error_type == "generation": + return GenerationError(message) + if error_type == "incomplete_generation": + return IncompleteGenerationError(message) + if error_type == "overloaded": + return OverloadedError(message) + if error_type == "validation": + return ValidationError(message) + + # Try to parse a APIInference error + if status_code == 400: + return BadRequestError(message) + if status_code == 403 or status_code == 424: + return ShardNotReadyError(message) + if status_code == 504: + return ShardTimeoutError(message) + if status_code == 404: + return NotFoundError(message) + if status_code == 429: + return RateLimitExceededError(message) + + # Fallback to an unknown error + return UnknownError(message) diff --git a/clients/python/text_generation/inference_api.py b/clients/python/text_generation/inference_api.py new file mode 100644 index 00000000..93b0de8d --- /dev/null +++ b/clients/python/text_generation/inference_api.py @@ -0,0 +1,168 @@ +import os +import requests + +from typing import Dict, Optional, List +from huggingface_hub.utils import build_hf_headers + +from text_generation import Client, AsyncClient, __version__ +from text_generation.types import DeployedModel +from text_generation.errors import NotSupportedError, parse_error + +INFERENCE_ENDPOINT = os.environ.get( + "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co" +) + + +def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]: + """ + Get all currently deployed models with text-generation-inference-support + + Returns: + List[DeployedModel]: list of all currently deployed models + """ + resp = requests.get( + f"https://api-inference.huggingface.co/framework/text-generation-inference", + headers=headers, + timeout=5, + ) + + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, payload) + + models = 
[DeployedModel(**raw_deployed_model) for raw_deployed_model in payload] + return models + + +def check_model_support(repo_id: str, headers: Optional[Dict] = None) -> bool: + """ + Check if a given model is supported by text-generation-inference + + Returns: + bool: whether the model is supported by this client + """ + resp = requests.get( + f"https://api-inference.huggingface.co/status/{repo_id}", + headers=headers, + timeout=5, + ) + + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, payload) + + framework = payload["framework"] + supported = framework == "text-generation-inference" + return supported + + +class InferenceAPIClient(Client): + """Client to make calls to the HuggingFace Inference API. + + Only supports a subset of the available text-generation or text2text-generation models that are served using + text-generation-inference + + Example: + + ```python + >>> from text_generation import InferenceAPIClient + + >>> client = InferenceAPIClient("bigscience/bloomz") + >>> client.generate("Why is the sky blue?").generated_text + ' Rayleigh scattering' + + >>> result = "" + >>> for response in client.generate_stream("Why is the sky blue?"): + >>> if not response.token.special: + >>> result += response.token.text + >>> result + ' Rayleigh scattering' + ``` + """ + + def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10): + """ + Init headers and API information + + Args: + repo_id (`str`): + Id of repository (e.g. `bigscience/bloom`). + token (`str`, `optional`): + The API token to use as HTTP bearer authorization. This is not + the authentication token. You can find the token in + https://huggingface.co/settings/token. Alternatively, you can + find both your organizations and personal API tokens using + `HfApi().whoami(token)`. + timeout (`int`): + Timeout in seconds + """ + + headers = build_hf_headers( + token=token, library_name="text-generation", library_version=__version__ + ) + + # Text Generation Inference client only supports a subset of the available hub models + if not check_model_support(repo_id, headers): + raise NotSupportedError(repo_id) + + base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}" + + super(InferenceAPIClient, self).__init__( + base_url, headers=headers, timeout=timeout + ) + + +class InferenceAPIAsyncClient(AsyncClient): + """Aynschronous Client to make calls to the HuggingFace Inference API. + + Only supports a subset of the available text-generation or text2text-generation models that are served using + text-generation-inference + + Example: + + ```python + >>> from text_generation import InferenceAPIAsyncClient + + >>> client = InferenceAPIAsyncClient("bigscience/bloomz") + >>> response = await client.generate("Why is the sky blue?") + >>> response.generated_text + ' Rayleigh scattering' + + >>> result = "" + >>> async for response in client.generate_stream("Why is the sky blue?"): + >>> if not response.token.special: + >>> result += response.token.text + >>> result + ' Rayleigh scattering' + ``` + """ + + def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10): + """ + Init headers and API information + + Args: + repo_id (`str`): + Id of repository (e.g. `bigscience/bloom`). + token (`str`, `optional`): + The API token to use as HTTP bearer authorization. This is not + the authentication token. You can find the token in + https://huggingface.co/settings/token. 
Alternatively, you can + find both your organizations and personal API tokens using + `HfApi().whoami(token)`. + timeout (`int`): + Timeout in seconds + """ + headers = build_hf_headers( + token=token, library_name="text-generation", library_version=__version__ + ) + + # Text Generation Inference client only supports a subset of the available hub models + if not check_model_support(repo_id, headers): + raise NotSupportedError(repo_id) + + base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}" + + super(InferenceAPIAsyncClient, self).__init__( + base_url, headers=headers, timeout=timeout + ) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py new file mode 100644 index 00000000..8195a08b --- /dev/null +++ b/clients/python/text_generation/types.py @@ -0,0 +1,462 @@ +from enum import Enum +from pydantic import BaseModel, field_validator +from typing import Optional, List, Union, Any + +from text_generation.errors import ValidationError + + +# enum for grammar type +class GrammarType(str, Enum): + Json = "json" + Regex = "regex" + + +# Grammar type and value +class Grammar(BaseModel): + # Grammar type + type: GrammarType + # Grammar value + value: Union[str, dict] + + +class ToolCall(BaseModel): + # Id of the tool call + id: int + # Type of the tool call + type: str + # Function details of the tool call + function: dict + + +class Message(BaseModel): + # Role of the message sender + role: str + # Content of the message + content: Optional[str] = None + # Optional name of the message sender + name: Optional[str] = None + # Tool calls associated with the chat completion + tool_calls: Optional[Any] = None + + +class Tool(BaseModel): + # Type of the tool + type: str + # Function details of the tool + function: dict + + +class Function(BaseModel): + name: Optional[str] + arguments: str + + +class ChoiceDeltaToolCall(BaseModel): + index: int + id: str + type: str + function: Function + + +class ChoiceDelta(BaseModel): + role: str + content: Optional[str] = None + tool_calls: Optional[ChoiceDeltaToolCall] + + +class Choice(BaseModel): + index: int + delta: ChoiceDelta + logprobs: Optional[dict] = None + finish_reason: Optional[str] = None + + +class CompletionRequest(BaseModel): + # Model identifier + model: str + # Prompt + prompt: str + # The parameter for repetition penalty. 1.0 means no penalty. + # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + repetition_penalty: Optional[float] = None + # The parameter for frequency penalty. 1.0 means no penalty + # Penalize new tokens based on their existing frequency in the text so far, + # decreasing the model's likelihood to repeat the same line verbatim. 
+ frequency_penalty: Optional[float] = None + # Maximum number of tokens to generate + max_tokens: Optional[int] = None + # Flag to indicate streaming response + stream: bool = False + # Random sampling seed + seed: Optional[int] = None + # Sampling temperature + temperature: Optional[float] = None + # Top-p value for nucleus sampling + top_p: Optional[float] = None + # Stop generating tokens if a member of `stop` is generated + stop: Optional[List[str]] = None + # LoRA id + lora_id: Optional[str] = None + + +class CompletionComplete(BaseModel): + # Index of the chat completion + index: int + # Message associated with the chat completion + text: str + # Log probabilities for the chat completion + logprobs: Optional[Any] + # Reason for completion + finish_reason: str + + +class Completion(BaseModel): + # Completion details + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[CompletionComplete] + + +class ChatRequest(BaseModel): + # Model identifier + model: str + # List of messages in the conversation + messages: List[Message] + # The parameter for repetition penalty. 1.0 means no penalty. + # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + repetition_penalty: Optional[float] = None + # The parameter for frequency penalty. 1.0 means no penalty + # Penalize new tokens based on their existing frequency in the text so far, + # decreasing the model's likelihood to repeat the same line verbatim. + frequency_penalty: Optional[float] = None + # Bias values for token selection + logit_bias: Optional[List[float]] = None + # Whether to return log probabilities + logprobs: Optional[bool] = None + # Number of most likely tokens to return at each position + top_logprobs: Optional[int] = None + # Maximum number of tokens to generate + max_tokens: Optional[int] = None + # Number of chat completion choices to generate + n: Optional[int] = None + # Penalty for presence of new tokens + presence_penalty: Optional[float] = None + # Flag to indicate streaming response + stream: bool = False + # Random sampling seed + seed: Optional[int] = None + # Sampling temperature + temperature: Optional[float] = None + # Top-p value for nucleus sampling + top_p: Optional[float] = None + # List of tools to be used + tools: Optional[List[Tool]] = None + # A prompt to be appended before the tools + tool_prompt: Optional[str] = None + # Choice of tool to be used + tool_choice: Optional[str] = None + # Stop generating tokens if a member of `stop` is generated + stop: Optional[List[str]] = None + # LoRA id + lora_id: Optional[str] = None + + +class ChatCompletionComplete(BaseModel): + # Index of the chat completion + index: int + # Message associated with the chat completion + message: Message + # Log probabilities for the chat completion + logprobs: Optional[Any] + # Reason for completion + finish_reason: str + # Usage details of the chat completion + usage: Optional[Any] = None + + +class ChatComplete(BaseModel): + # Chat completion details + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[ChatCompletionComplete] + usage: Any + + +class ChatCompletionChunk(BaseModel): + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[Choice] + + +class Parameters(BaseModel): + # Activate logits sampling + do_sample: bool = False + # Maximum number of generated tokens + max_new_tokens: int = 20 + # The parameter for repetition penalty. 1.0 means no penalty. 
+ # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + repetition_penalty: Optional[float] = None + # The parameter for frequency penalty. 1.0 means no penalty + # Penalize new tokens based on their existing frequency in the text so far, + # decreasing the model's likelihood to repeat the same line verbatim. + frequency_penalty: Optional[float] = None + # Whether to prepend the prompt to the generated text + return_full_text: bool = False + # Stop generating tokens if a member of `stop_sequences` is generated + stop: List[str] = [] + # Random sampling seed + seed: Optional[int] = None + # The value used to module the logits distribution. + temperature: Optional[float] = None + # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_k: Optional[int] = None + # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + # higher are kept for generation. + top_p: Optional[float] = None + # truncate inputs tokens to the given size + truncate: Optional[int] = None + # Typical Decoding mass + # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + typical_p: Optional[float] = None + # Generate best_of sequences and return the one if the highest token logprobs + best_of: Optional[int] = None + # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + watermark: bool = False + # Get generation details + details: bool = False + # Get decoder input token logprobs and ids + decoder_input_details: bool = False + # Return the N most likely tokens at each step + top_n_tokens: Optional[int] = None + # grammar to use for generation + grammar: Optional[Grammar] = None + + @field_validator("best_of") + def valid_best_of(cls, field_value, values): + if field_value is not None: + if field_value <= 0: + raise ValidationError("`best_of` must be strictly positive") + if field_value > 1 and values.data["seed"] is not None: + raise ValidationError("`seed` must not be set when `best_of` is > 1") + sampling = ( + values.data["do_sample"] + | (values.data["temperature"] is not None) + | (values.data["top_k"] is not None) + | (values.data["top_p"] is not None) + | (values.data["typical_p"] is not None) + ) + if field_value > 1 and not sampling: + raise ValidationError("you must use sampling when `best_of` is > 1") + + return field_value + + @field_validator("repetition_penalty") + def valid_repetition_penalty(cls, v): + if v is not None and v <= 0: + raise ValidationError("`repetition_penalty` must be strictly positive") + return v + + @field_validator("frequency_penalty") + def valid_frequency_penalty(cls, v): + if v is not None and v <= 0: + raise ValidationError("`frequency_penalty` must be strictly positive") + return v + + @field_validator("seed") + def valid_seed(cls, v): + if v is not None and v < 0: + raise ValidationError("`seed` must be positive") + return v + + @field_validator("temperature") + def valid_temp(cls, v): + if v is not None and v <= 0: + raise ValidationError("`temperature` must be strictly positive") + return v + + @field_validator("top_k") + def valid_top_k(cls, v): + if v is not None and v <= 0: + raise ValidationError("`top_k` must be strictly positive") + return v + + @field_validator("top_p") + def valid_top_p(cls, v): + if v is not None and (v <= 0 or v >= 1.0): + raise ValidationError("`top_p` must be > 0.0 and < 1.0") + return v + + @field_validator("truncate") + def valid_truncate(cls, 
v): + if v is not None and v <= 0: + raise ValidationError("`truncate` must be strictly positive") + return v + + @field_validator("typical_p") + def valid_typical_p(cls, v): + if v is not None and (v <= 0 or v >= 1.0): + raise ValidationError("`typical_p` must be > 0.0 and < 1.0") + return v + + @field_validator("top_n_tokens") + def valid_top_n_tokens(cls, v): + if v is not None and v <= 0: + raise ValidationError("`top_n_tokens` must be strictly positive") + return v + + @field_validator("grammar") + def valid_grammar(cls, v): + if v is not None: + if v.type == GrammarType.Regex and not v.value: + raise ValidationError("`value` cannot be empty for `regex` grammar") + if v.type == GrammarType.Json and not v.value: + raise ValidationError("`value` cannot be empty for `json` grammar") + return v + + +class Request(BaseModel): + # Prompt + inputs: str + # Generation parameters + parameters: Optional[Parameters] = None + # Whether to stream output tokens + stream: bool = False + # LoRA id + lora_id: Optional[str] = None + + @field_validator("inputs") + def valid_input(cls, v): + if not v: + raise ValidationError("`inputs` cannot be empty") + return v + + @field_validator("stream") + def valid_best_of_stream(cls, field_value, values): + parameters = values.data["parameters"] + if ( + parameters is not None + and parameters.best_of is not None + and parameters.best_of > 1 + and field_value + ): + raise ValidationError( + "`best_of` != 1 is not supported when `stream` == True" + ) + return field_value + + +# Decoder input tokens +class InputToken(BaseModel): + # Token ID from the model tokenizer + id: int + # Token text + text: str + # Logprob + # Optional since the logprob of the first token cannot be computed + logprob: Optional[float] = None + + +# Generated tokens +class Token(BaseModel): + # Token ID from the model tokenizer + id: int + # Token text + text: str + # Logprob + logprob: Optional[float] = None + # Is the token a special token + # Can be used to ignore tokens when concatenating + special: bool + + +# Generation finish reason +class FinishReason(str, Enum): + # number of generated tokens == `max_new_tokens` + Length = "length" + # the model generated its end of sequence token + EndOfSequenceToken = "eos_token" + # the model generated a text included in `stop_sequences` + StopSequence = "stop_sequence" + + +# Additional sequences when using the `best_of` parameter +class BestOfSequence(BaseModel): + # Generated text + generated_text: str + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] = None + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] + # Generated tokens + tokens: List[Token] + # Most likely tokens + top_tokens: Optional[List[List[Token]]] = None + + +# `generate` details +class Details(BaseModel): + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] = None + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] + # Generated tokens + tokens: List[Token] + # Most likely tokens + top_tokens: Optional[List[List[Token]]] = None + # Additional sequences when using the `best_of` parameter + best_of_sequences: Optional[List[BestOfSequence]] = None + + +# `generate` return value +class Response(BaseModel): + # Generated text + generated_text: 
str + # Generation details + details: Details + + +# `generate_stream` details +class StreamDetails(BaseModel): + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] = None + + +# `generate_stream` return value +class StreamResponse(BaseModel): + # Generated token + token: Token + # Most likely tokens + top_tokens: Optional[List[Token]] = None + # Complete generated text + # Only available when the generation is finished + generated_text: Optional[str] = None + # Generation details + # Only available when the generation is finished + details: Optional[StreamDetails] = None + + +# Inference API currently deployed model +class DeployedModel(BaseModel): + model_id: str + sha: str diff --git a/copy_back.py b/copy_back.py deleted file mode 100644 index 520a9533..00000000 --- a/copy_back.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import print_function -import filecmp -import os.path -import sys -import shutil -import os - -compare_file_data = True - -files = [] - -def compare_dir_trees(dir1, dir2, compare_file_data, output): - def compare_dirs(dir1, dir2): - dirs_cmp = filecmp.dircmp(dir1, dir2) - if compare_file_data and dirs_cmp.diff_files: - for f in dirs_cmp.diff_files: - files.append(dir1+'/' + f) - for common_dir in dirs_cmp.common_dirs: - new_dir1 = os.path.join(dir1, common_dir) - new_dir2 = os.path.join(dir2, common_dir) - compare_dirs(new_dir1, new_dir2) - compare_dirs(dir1, dir2) - -dirs = ['server', 'clients', 'launcher', 'benchmark', 'integration-tests', 'load_tests', 'proto', 'router'] -for dir in dirs: - if os.path.exists(dir): - dir_a = 'build/' + dir - dir_b = dir - compare_dir_trees(dir_a, dir_b, compare_file_data, sys.stdout) - -for file in files: - print(file + " -> " + file.replace('build/', '')) - os.remove(file.replace('build/', '')) - shutil.copy(file, file.replace('build/', '')) diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 00000000..fb2ff198 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,10 @@ +Documentation available at: https://huggingface.co/docs/text-generation-inference + +## Release + +When making a release, please update the latest version in the documentation with: +``` +export OLD_VERSION="2\.0\.3" +export NEW_VERSION="2\.0\.4" +find . -name '*.md' -exec sed -i -e "s/$OLD_VERSION/$NEW_VERSION/g" {} \; +``` diff --git a/docs/index.html b/docs/index.html new file mode 100644 index 00000000..f582d3ce --- /dev/null +++ b/docs/index.html @@ -0,0 +1,30 @@ + + + + + + Text Generation Inference API + + +
+ + + diff --git a/docs/openapi.json b/docs/openapi.json new file mode 100644 index 00000000..79c3b80f --- /dev/null +++ b/docs/openapi.json @@ -0,0 +1,1883 @@ +{ + "openapi": "3.0.3", + "info": { + "title": "Text Generation Inference", + "description": "Text Generation Webserver", + "contact": { + "name": "Olivier Dehaene" + }, + "license": { + "name": "Apache 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + }, + "version": "2.0.1" + }, + "paths": { + "/": { + "post": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`", + "description": "Generate tokens if `stream == false` or a stream of token if `stream == true`", + "operationId": "compat_generate", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CompatGenerateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Generated Text", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GenerateResponse" + } + }, + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/StreamResponse" + } + } + } + }, + "422": { + "description": "Input validation error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Input validation error" + } + } + } + }, + "424": { + "description": "Generation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Request failed during generation" + } + } + } + }, + "429": { + "description": "Model is overloaded", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Model is overloaded" + } + } + } + }, + "500": { + "description": "Incomplete generation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Incomplete generation" + } + } + } + } + } + } + }, + "/generate": { + "post": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Generate tokens", + "description": "Generate tokens", + "operationId": "generate", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GenerateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Generated Text", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GenerateResponse" + } + } + } + }, + "422": { + "description": "Input validation error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Input validation error" + } + } + } + }, + "424": { + "description": "Generation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Request failed during generation" + } + } + } + }, + "429": { + "description": "Model is overloaded", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Model is overloaded" + } + } + } + }, + "500": { + "description": "Incomplete generation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Incomplete generation" + } + } + } + } + } + } + }, 
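For reference, the `/generate` route documented above can also be exercised directly over HTTP. The sketch below uses `requests` and mirrors the body the Python client builds (an `inputs` string plus a `parameters` object); the host, prompt, and `max_new_tokens` value are placeholders, and the response is assumed to expose a `generated_text` field as the client-side `Response` model suggests.

```python
# Minimal sketch of a direct call to the /generate route described above.
# Assumptions (not part of this diff): host, prompt, and parameter values
# are placeholders; a TGI server is assumed to be listening on port 8080.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "My name is Olivier and I",
        "parameters": {"max_new_tokens": 20},
    },
    timeout=10,
)
resp.raise_for_status()  # 422/424/429/500 carry an ErrorResponse body instead
print(resp.json()["generated_text"])
```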
+ "/generate_stream": { + "post": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Generate a stream of token using Server-Sent Events", + "description": "Generate a stream of token using Server-Sent Events", + "operationId": "generate_stream", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GenerateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Generated Text", + "content": { + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/StreamResponse" + } + } + } + }, + "422": { + "description": "Input validation error", + "content": { + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Input validation error" + } + } + } + }, + "424": { + "description": "Generation Error", + "content": { + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Request failed during generation" + } + } + } + }, + "429": { + "description": "Model is overloaded", + "content": { + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Model is overloaded" + } + } + } + }, + "500": { + "description": "Incomplete generation", + "content": { + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Incomplete generation" + } + } + } + } + } + } + }, + "/health": { + "get": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Health check method", + "description": "Health check method", + "operationId": "health", + "responses": { + "200": { + "description": "Everything is working fine" + }, + "503": { + "description": "Text generation inference is down", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "unhealthy", + "error_type": "healthcheck" + } + } + } + } + } + } + }, + "/info": { + "get": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Text Generation Inference endpoint info", + "description": "Text Generation Inference endpoint info", + "operationId": "get_model_info", + "responses": { + "200": { + "description": "Served model info", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Info" + } + } + } + } + } + } + }, + "/metrics": { + "get": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Prometheus metrics scrape endpoint", + "description": "Prometheus metrics scrape endpoint", + "operationId": "metrics", + "responses": { + "200": { + "description": "Prometheus Metrics", + "content": { + "text/plain": { + "schema": { + "type": "string" + } + } + } + } + } + } + }, + "/tokenize": { + "post": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Tokenize inputs", + "description": "Tokenize inputs", + "operationId": "tokenize", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GenerateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Tokenized ids", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TokenizeResponse" + } + } + } + }, + "404": { + "description": "No tokenizer found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "No fast tokenizer available" + } + } + 
} + } + } + } + }, + "/v1/chat/completions": { + "post": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Generate tokens", + "description": "Generate tokens", + "operationId": "chat_completions", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ChatRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Generated Chat Completion", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ChatCompletion" + } + }, + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/ChatCompletionChunk" + } + } + } + }, + "422": { + "description": "Input validation error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Input validation error" + } + } + } + }, + "424": { + "description": "Generation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Request failed during generation" + } + } + } + }, + "429": { + "description": "Model is overloaded", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Model is overloaded" + } + } + } + }, + "500": { + "description": "Incomplete generation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Incomplete generation" + } + } + } + } + } + } + }, + "/v1/completions": { + "post": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Generate tokens", + "description": "Generate tokens", + "operationId": "completions", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CompletionRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Generated Chat Completion", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Completion" + } + }, + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/CompletionCompleteChunk" + } + } + } + }, + "422": { + "description": "Input validation error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Input validation error" + } + } + } + }, + "424": { + "description": "Generation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Request failed during generation" + } + } + } + }, + "429": { + "description": "Model is overloaded", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Model is overloaded" + } + } + } + }, + "500": { + "description": "Incomplete generation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Incomplete generation" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "BestOfSequence": { + "type": "object", + "required": [ + "generated_text", + "finish_reason", + "generated_tokens", + "prefill", + "tokens" + ], + "properties": { + "finish_reason": { + "$ref": "#/components/schemas/FinishReason" + }, + "generated_text": { + "type": "string", + "example": "test" + }, + "generated_tokens": { + "type": "integer", + "format": "int32", + "example": 1, + 
"minimum": 0 + }, + "prefill": { + "type": "array", + "items": { + "$ref": "#/components/schemas/PrefillToken" + } + }, + "seed": { + "type": "integer", + "format": "int64", + "example": 42, + "nullable": true, + "minimum": 0 + }, + "tokens": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Token" + } + }, + "top_tokens": { + "type": "array", + "items": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Token" + } + } + } + } + }, + "ChatCompletion": { + "type": "object", + "required": [ + "id", + "object", + "created", + "model", + "system_fingerprint", + "choices", + "usage" + ], + "properties": { + "choices": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ChatCompletionComplete" + } + }, + "created": { + "type": "integer", + "format": "int64", + "example": "1706270835", + "minimum": 0 + }, + "id": { + "type": "string" + }, + "model": { + "type": "string", + "example": "mistralai/Mistral-7B-Instruct-v0.2" + }, + "object": { + "type": "string" + }, + "system_fingerprint": { + "type": "string" + }, + "usage": { + "$ref": "#/components/schemas/Usage" + } + } + }, + "ChatCompletionChoice": { + "type": "object", + "required": [ + "index", + "delta" + ], + "properties": { + "delta": { + "$ref": "#/components/schemas/ChatCompletionDelta" + }, + "finish_reason": { + "type": "string", + "nullable": true + }, + "index": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "logprobs": { + "allOf": [ + { + "$ref": "#/components/schemas/ChatCompletionLogprobs" + } + ], + "nullable": true + } + } + }, + "ChatCompletionChunk": { + "type": "object", + "required": [ + "id", + "object", + "created", + "model", + "system_fingerprint", + "choices" + ], + "properties": { + "choices": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ChatCompletionChoice" + } + }, + "created": { + "type": "integer", + "format": "int64", + "example": "1706270978", + "minimum": 0 + }, + "id": { + "type": "string" + }, + "model": { + "type": "string", + "example": "mistralai/Mistral-7B-Instruct-v0.2" + }, + "object": { + "type": "string" + }, + "system_fingerprint": { + "type": "string" + } + } + }, + "ChatCompletionComplete": { + "type": "object", + "required": [ + "index", + "message", + "finish_reason" + ], + "properties": { + "finish_reason": { + "type": "string" + }, + "index": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "logprobs": { + "allOf": [ + { + "$ref": "#/components/schemas/ChatCompletionLogprobs" + } + ], + "nullable": true + }, + "message": { + "$ref": "#/components/schemas/Message" + } + } + }, + "ChatCompletionDelta": { + "type": "object", + "required": [ + "role" + ], + "properties": { + "content": { + "type": "string", + "example": "What is Deep Learning?", + "nullable": true + }, + "role": { + "type": "string", + "example": "user" + }, + "tool_calls": { + "allOf": [ + { + "$ref": "#/components/schemas/DeltaToolCall" + } + ], + "nullable": true + } + } + }, + "ChatCompletionLogprob": { + "type": "object", + "required": [ + "token", + "logprob", + "top_logprobs" + ], + "properties": { + "logprob": { + "type": "number", + "format": "float" + }, + "token": { + "type": "string" + }, + "top_logprobs": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ChatCompletionTopLogprob" + } + } + } + }, + "ChatCompletionLogprobs": { + "type": "object", + "required": [ + "content" + ], + "properties": { + "content": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ChatCompletionLogprob" + } + 
} + } + }, + "ChatCompletionTopLogprob": { + "type": "object", + "required": [ + "token", + "logprob" + ], + "properties": { + "logprob": { + "type": "number", + "format": "float" + }, + "token": { + "type": "string" + } + } + }, + "ChatRequest": { + "type": "object", + "required": [ + "model", + "messages" + ], + "properties": { + "frequency_penalty": { + "type": "number", + "format": "float", + "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.", + "example": "1.0", + "nullable": true + }, + "logit_bias": { + "type": "array", + "items": { + "type": "number", + "format": "float" + }, + "description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.", + "nullable": true + }, + "logprobs": { + "type": "boolean", + "description": "Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each\noutput token returned in the content of message.", + "example": "false", + "nullable": true + }, + "max_tokens": { + "type": "integer", + "format": "int32", + "description": "The maximum number of tokens that can be generated in the chat completion.", + "example": "32", + "nullable": true, + "minimum": 0 + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Message" + }, + "description": "A list of messages comprising the conversation so far.", + "example": "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]" + }, + "model": { + "type": "string", + "description": "[UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", + "example": "mistralai/Mistral-7B-Instruct-v0.2" + }, + "n": { + "type": "integer", + "format": "int32", + "description": "UNUSED\nHow many chat completion choices to generate for each input message. Note that you will be charged based on the\nnumber of generated tokens across all of the choices. Keep n as 1 to minimize costs.", + "example": "2", + "nullable": true, + "minimum": 0 + }, + "presence_penalty": { + "type": "number", + "format": "float", + "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,\nincreasing the model's likelihood to talk about new topics", + "example": 0.1, + "nullable": true + }, + "seed": { + "type": "integer", + "format": "int64", + "example": 42, + "nullable": true, + "minimum": 0 + }, + "stop": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Up to 4 sequences where the API will stop generating further tokens.", + "example": "null", + "nullable": true + }, + "stream": { + "type": "boolean" + }, + "temperature": { + "type": "number", + "format": "float", + "description": "What sampling temperature to use, between 0 and 2. 
Higher values like 0.8 will make the output more random, while\nlower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.", + "example": 1.0, + "nullable": true + }, + "tool_choice": { + "allOf": [ + { + "$ref": "#/components/schemas/ToolType" + } + ], + "nullable": true + }, + "tool_prompt": { + "type": "string", + "description": "A prompt to be appended before the tools", + "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"", + "nullable": true + }, + "tools": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Tool" + }, + "description": "A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of\nfunctions the model may generate JSON inputs for.", + "example": "null", + "nullable": true + }, + "top_logprobs": { + "type": "integer", + "format": "int32", + "description": "An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\nan associated log probability. logprobs must be set to true if this parameter is used.", + "example": "5", + "nullable": true, + "minimum": 0 + }, + "top_p": { + "type": "number", + "format": "float", + "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.", + "example": 0.95, + "nullable": true + } + } + }, + "CompatGenerateRequest": { + "type": "object", + "required": [ + "inputs" + ], + "properties": { + "inputs": { + "type": "string", + "example": "My name is Olivier and I" + }, + "parameters": { + "$ref": "#/components/schemas/GenerateParameters" + }, + "stream": { + "type": "boolean", + "default": "false" + } + } + }, + "CompletionComplete": { + "type": "object", + "required": [ + "index", + "text", + "finish_reason" + ], + "properties": { + "finish_reason": { + "type": "string" + }, + "index": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "logprobs": { + "type": "array", + "items": { + "type": "number", + "format": "float" + }, + "nullable": true + }, + "text": { + "type": "string" + } + } + }, + "CompletionCompleteChunk": { + "type": "object", + "required": [ + "id", + "object", + "created", + "choices", + "model", + "system_fingerprint" + ], + "properties": { + "choices": { + "type": "array", + "items": { + "$ref": "#/components/schemas/CompletionComplete" + } + }, + "created": { + "type": "integer", + "format": "int64", + "minimum": 0 + }, + "id": { + "type": "string" + }, + "model": { + "type": "string" + }, + "object": { + "type": "string" + }, + "system_fingerprint": { + "type": "string" + } + } + }, + "CompletionRequest": { + "type": "object", + "required": [ + "model", + "prompt" + ], + "properties": { + "frequency_penalty": { + "type": "number", + "format": "float", + "description": "Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.", + "example": "1.0", + "nullable": true + }, + "max_tokens": { + "type": "integer", + "format": "int32", + "description": "The maximum number of tokens that can be generated in the chat completion.", + "default": "32", + "nullable": true, + "minimum": 0 + }, + "model": { + "type": "string", + "description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", + "example": "mistralai/Mistral-7B-Instruct-v0.2" + }, + "prompt": { + "type": "array", + "items": { + "type": "string" + }, + "description": "The prompt to generate completions for.", + "example": "What is Deep Learning?" + }, + "repetition_penalty": { + "type": "number", + "format": "float", + "nullable": true + }, + "seed": { + "type": "integer", + "format": "int64", + "example": 42, + "nullable": true, + "minimum": 0 + }, + "stream": { + "type": "boolean" + }, + "suffix": { + "type": "string", + "description": "The text to append to the prompt. This is useful for completing sentences or generating a paragraph of text.\nplease see the completion_template field in the model's tokenizer_config.json file for completion template.", + "nullable": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while\nlower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.", + "example": 1.0, + "nullable": true + }, + "top_p": { + "type": "number", + "format": "float", + "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. 
So 0.1 means only the tokens comprising the top 10% probability mass are considered.", + "example": 0.95, + "nullable": true + }, + "stop": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Up to 4 sequences where the API will stop generating further tokens.", + "example": "null", + "nullable": true + } + } + }, + "DeltaToolCall": { + "type": "object", + "required": [ + "index", + "id", + "type", + "function" + ], + "properties": { + "function": { + "$ref": "#/components/schemas/Function" + }, + "id": { + "type": "string" + }, + "index": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "type": { + "type": "string" + } + } + }, + "Details": { + "type": "object", + "required": [ + "finish_reason", + "generated_tokens", + "prefill", + "tokens" + ], + "properties": { + "best_of_sequences": { + "type": "array", + "items": { + "$ref": "#/components/schemas/BestOfSequence" + }, + "nullable": true + }, + "finish_reason": { + "$ref": "#/components/schemas/FinishReason" + }, + "generated_tokens": { + "type": "integer", + "format": "int32", + "example": 1, + "minimum": 0 + }, + "prefill": { + "type": "array", + "items": { + "$ref": "#/components/schemas/PrefillToken" + } + }, + "seed": { + "type": "integer", + "format": "int64", + "example": 42, + "nullable": true, + "minimum": 0 + }, + "tokens": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Token" + } + }, + "top_tokens": { + "type": "array", + "items": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Token" + } + } + } + } + }, + "ErrorResponse": { + "type": "object", + "required": [ + "error", + "error_type" + ], + "properties": { + "error": { + "type": "string" + }, + "error_type": { + "type": "string" + } + } + }, + "FinishReason": { + "type": "string", + "enum": [ + "length", + "eos_token", + "stop_sequence" + ], + "example": "Length" + }, + "Function": { + "type": "object", + "required": [ + "arguments" + ], + "properties": { + "arguments": { + "type": "string" + }, + "name": { + "type": "string", + "nullable": true + } + } + }, + "FunctionDefinition": { + "type": "object", + "required": [ + "name", + "arguments" + ], + "properties": { + "arguments": {}, + "description": { + "type": "string", + "nullable": true + }, + "name": { + "type": "string" + } + } + }, + "GenerateParameters": { + "type": "object", + "properties": { + "best_of": { + "type": "integer", + "default": "null", + "example": 1, + "nullable": true, + "minimum": 0, + "exclusiveMinimum": 0 + }, + "decoder_input_details": { + "type": "boolean", + "default": "false" + }, + "details": { + "type": "boolean", + "default": "true" + }, + "do_sample": { + "type": "boolean", + "default": "false", + "example": true + }, + "frequency_penalty": { + "type": "number", + "format": "float", + "default": "null", + "example": 0.1, + "nullable": true, + "exclusiveMinimum": -2 + }, + "grammar": { + "allOf": [ + { + "$ref": "#/components/schemas/GrammarType" + } + ], + "default": "null", + "nullable": true + }, + "max_new_tokens": { + "type": "integer", + "format": "int32", + "default": "100", + "example": "20", + "nullable": true, + "minimum": 0 + }, + "repetition_penalty": { + "type": "number", + "format": "float", + "default": "null", + "example": 1.03, + "nullable": true, + "exclusiveMinimum": 0 + }, + "return_full_text": { + "type": "boolean", + "default": "null", + "example": false, + "nullable": true + }, + "seed": { + "type": "integer", + "format": "int64", + "default": "null", + "example": "null", + "nullable": 
true, + "minimum": 0, + "exclusiveMinimum": 0 + }, + "stop": { + "type": "array", + "items": { + "type": "string" + }, + "example": [ + "photographer" + ], + "maxItems": 4 + }, + "temperature": { + "type": "number", + "format": "float", + "default": "null", + "example": 0.5, + "nullable": true, + "exclusiveMinimum": 0 + }, + "top_k": { + "type": "integer", + "format": "int32", + "default": "null", + "example": 10, + "nullable": true, + "exclusiveMinimum": 0 + }, + "top_n_tokens": { + "type": "integer", + "format": "int32", + "default": "null", + "example": 5, + "nullable": true, + "minimum": 0, + "exclusiveMinimum": 0 + }, + "top_p": { + "type": "number", + "format": "float", + "default": "null", + "example": 0.95, + "nullable": true, + "maximum": 1, + "exclusiveMinimum": 0 + }, + "truncate": { + "type": "integer", + "default": "null", + "example": "null", + "nullable": true, + "minimum": 0 + }, + "typical_p": { + "type": "number", + "format": "float", + "default": "null", + "example": 0.95, + "nullable": true, + "maximum": 1, + "exclusiveMinimum": 0 + }, + "watermark": { + "type": "boolean", + "default": "false", + "example": true + } + } + }, + "GenerateRequest": { + "type": "object", + "required": [ + "inputs" + ], + "properties": { + "inputs": { + "type": "string", + "example": "My name is Olivier and I" + }, + "parameters": { + "$ref": "#/components/schemas/GenerateParameters" + } + } + }, + "GenerateResponse": { + "type": "object", + "required": [ + "generated_text" + ], + "properties": { + "details": { + "allOf": [ + { + "$ref": "#/components/schemas/Details" + } + ], + "nullable": true + }, + "generated_text": { + "type": "string", + "example": "test" + } + } + }, + "GrammarType": { + "oneOf": [ + { + "type": "object", + "required": [ + "type", + "value" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "json" + ] + }, + "value": { + "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions." 
+ } + } + }, + { + "type": "object", + "required": [ + "type", + "value" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "regex" + ] + }, + "value": { + "type": "string" + } + } + } + ], + "discriminator": { + "propertyName": "type" + } + }, + "Info": { + "type": "object", + "required": [ + "model_id", + "model_dtype", + "model_device_type", + "max_concurrent_requests", + "max_best_of", + "max_stop_sequences", + "max_input_length", + "max_total_tokens", + "waiting_served_ratio", + "max_batch_total_tokens", + "max_waiting_tokens", + "validation_workers", + "max_client_batch_size", + "version" + ], + "properties": { + "docker_label": { + "type": "string", + "example": "null", + "nullable": true + }, + "max_batch_size": { + "type": "integer", + "example": "null", + "nullable": true, + "minimum": 0 + }, + "max_batch_total_tokens": { + "type": "integer", + "format": "int32", + "example": "32000", + "minimum": 0 + }, + "max_best_of": { + "type": "integer", + "example": "2", + "minimum": 0 + }, + "max_client_batch_size": { + "type": "integer", + "example": "32", + "minimum": 0 + }, + "max_concurrent_requests": { + "type": "integer", + "description": "Router Parameters", + "example": "128", + "minimum": 0 + }, + "max_input_length": { + "type": "integer", + "example": "1024", + "minimum": 0 + }, + "max_stop_sequences": { + "type": "integer", + "example": "4", + "minimum": 0 + }, + "max_total_tokens": { + "type": "integer", + "example": "2048", + "minimum": 0 + }, + "max_waiting_tokens": { + "type": "integer", + "example": "20", + "minimum": 0 + }, + "model_device_type": { + "type": "string", + "example": "cuda" + }, + "model_dtype": { + "type": "string", + "example": "torch.float16" + }, + "model_id": { + "type": "string", + "description": "Model info", + "example": "bigscience/blomm-560m" + }, + "model_pipeline_tag": { + "type": "string", + "example": "text-generation", + "nullable": true + }, + "model_sha": { + "type": "string", + "example": "e985a63cdc139290c5f700ff1929f0b5942cced2", + "nullable": true + }, + "sha": { + "type": "string", + "example": "null", + "nullable": true + }, + "validation_workers": { + "type": "integer", + "example": "2", + "minimum": 0 + }, + "version": { + "type": "string", + "description": "Router Info", + "example": "0.5.0" + }, + "waiting_served_ratio": { + "type": "number", + "format": "float", + "example": "1.2" + } + } + }, + "Message": { + "type": "object", + "required": [ + "role" + ], + "properties": { + "content": { + "type": "string", + "example": "My name is David and I", + "nullable": true + }, + "name": { + "type": "string", + "example": "\"David\"", + "nullable": true + }, + "role": { + "type": "string", + "example": "user" + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolCall" + }, + "nullable": true + } + } + }, + "PrefillToken": { + "type": "object", + "required": [ + "id", + "text", + "logprob" + ], + "properties": { + "id": { + "type": "integer", + "format": "int32", + "example": 0, + "minimum": 0 + }, + "logprob": { + "type": "number", + "format": "float", + "example": -0.34, + "nullable": true + }, + "text": { + "type": "string", + "example": "test" + } + } + }, + "SimpleToken": { + "type": "object", + "required": [ + "id", + "text", + "start", + "stop" + ], + "properties": { + "id": { + "type": "integer", + "format": "int32", + "example": 0, + "minimum": 0 + }, + "start": { + "type": "integer", + "example": 0, + "minimum": 0 + }, + "stop": { + "type": "integer", + "example": 2, + 
"minimum": 0 + }, + "text": { + "type": "string", + "example": "test" + } + } + }, + "StreamDetails": { + "type": "object", + "required": [ + "finish_reason", + "generated_tokens" + ], + "properties": { + "finish_reason": { + "$ref": "#/components/schemas/FinishReason" + }, + "generated_tokens": { + "type": "integer", + "format": "int32", + "example": 1, + "minimum": 0 + }, + "seed": { + "type": "integer", + "format": "int64", + "example": 42, + "nullable": true, + "minimum": 0 + } + } + }, + "StreamResponse": { + "type": "object", + "required": [ + "index", + "token" + ], + "properties": { + "details": { + "allOf": [ + { + "$ref": "#/components/schemas/StreamDetails" + } + ], + "default": "null", + "nullable": true + }, + "generated_text": { + "type": "string", + "default": "null", + "example": "test", + "nullable": true + }, + "index": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "token": { + "$ref": "#/components/schemas/Token" + }, + "top_tokens": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Token" + } + } + } + }, + "Token": { + "type": "object", + "required": [ + "id", + "text", + "logprob", + "special" + ], + "properties": { + "id": { + "type": "integer", + "format": "int32", + "example": 0, + "minimum": 0 + }, + "logprob": { + "type": "number", + "format": "float", + "example": -0.34, + "nullable": true + }, + "special": { + "type": "boolean", + "example": "false" + }, + "text": { + "type": "string", + "example": "test" + } + } + }, + "TokenizeResponse": { + "type": "array", + "items": { + "$ref": "#/components/schemas/SimpleToken" + } + }, + "Tool": { + "type": "object", + "required": [ + "type", + "function" + ], + "properties": { + "function": { + "$ref": "#/components/schemas/FunctionDefinition" + }, + "type": { + "type": "string", + "example": "function" + } + } + }, + "ToolCall": { + "type": "object", + "required": [ + "id", + "type", + "function" + ], + "properties": { + "function": { + "$ref": "#/components/schemas/FunctionDefinition" + }, + "id": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "type": { + "type": "string" + } + } + }, + "ToolType": { + "oneOf": [ + { + "type": "object", + "required": [ + "FunctionName" + ], + "properties": { + "FunctionName": { + "type": "string" + } + } + }, + { + "type": "string", + "enum": [ + "OneOf" + ] + } + ] + }, + "Usage": { + "type": "object", + "required": [ + "prompt_tokens", + "completion_tokens", + "total_tokens" + ], + "properties": { + "completion_tokens": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "prompt_tokens": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "total_tokens": { + "type": "integer", + "format": "int32", + "minimum": 0 + } + } + } + } + }, + "tags": [ + { + "name": "Text Generation Inference", + "description": "Hugging Face Text Generation Inference API" + } + ] +} diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml new file mode 100644 index 00000000..a7351a33 --- /dev/null +++ b/docs/source/_toctree.yml @@ -0,0 +1,63 @@ +- sections: + - local: index + title: Text Generation Inference + - local: quicktour + title: Quick Tour + - local: installation_nvidia + title: Using TGI with Nvidia GPUs + - local: installation_amd + title: Using TGI with AMD GPUs + - local: installation_gaudi + title: Using TGI with Intel Gaudi + - local: installation_inferentia + title: Using TGI with AWS Inferentia + - local: installation + title: Installation from source + - local: supported_models + title: Supported 
Models and Hardware + - local: messages_api + title: Messages API + title: Getting started +- sections: + - local: basic_tutorials/consuming_tgi + title: Consuming TGI + - local: basic_tutorials/preparing_model + title: Preparing Model for Serving + - local: basic_tutorials/gated_model_access + title: Serving Private & Gated Models + - local: basic_tutorials/using_cli + title: Using TGI CLI + - local: basic_tutorials/launcher + title: All TGI CLI options + - local: basic_tutorials/non_core_models + title: Non-core Model Serving + - local: basic_tutorials/safety + title: Safety + - local: basic_tutorials/using_guidance + title: Using Guidance, JSON, tools + - local: basic_tutorials/visual_language_models + title: Visual Language Models + - local: basic_tutorials/monitoring + title: Monitoring TGI with Prometheus and Grafana + - local: basic_tutorials/train_medusa + title: Train Medusa + title: Tutorials +- sections: + - local: conceptual/streaming + title: Streaming + - local: conceptual/quantization + title: Quantization + - local: conceptual/tensor_parallelism + title: Tensor Parallelism + - local: conceptual/paged_attention + title: PagedAttention + - local: conceptual/safetensors + title: Safetensors + - local: conceptual/flash_attention + title: Flash Attention + - local: conceptual/speculation + title: Speculation (Medusa, ngram) + - local: conceptual/guidance + title: How Guidance Works (via outlines) + + title: Conceptual Guides diff --git a/docs/source/basic_tutorials/consuming_tgi.md b/docs/source/basic_tutorials/consuming_tgi.md new file mode 100644 index 00000000..4829ec7c --- /dev/null +++ b/docs/source/basic_tutorials/consuming_tgi.md @@ -0,0 +1,155 @@ +# Consuming Text Generation Inference + +There are many ways you can consume Text Generation Inference server in your applications. After launching, you can use the `/generate` route and make a `POST` request to get results from the server. You can also use the `/generate_stream` route if you want TGI to return a stream of tokens. You can make the requests using the tool of your preference, such as curl, Python or TypeScrpt. For a final end-to-end experience, we also open-sourced ChatUI, a chat interface for open-source models. + +## curl + +After the launch, you can query the model using either the `/generate` or `/generate_stream` routes: + +```bash +curl 127.0.0.1:8080/generate \ + -X POST \ + -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ + -H 'Content-Type: application/json' +``` + + +## Inference Client + +[`huggingface-hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a nice high-level class, [`~huggingface_hub.InferenceClient`], which makes it easy to make calls to a TGI endpoint. `InferenceClient` also takes care of parameter validation and provides a simple to-use interface. +You can simply install `huggingface-hub` package with pip. + +```bash +pip install huggingface-hub +``` + +Once you start the TGI server, instantiate `InferenceClient()` with the URL to the endpoint serving the model. You can then call `text_generation()` to hit the endpoint through Python. + +```python +from huggingface_hub import InferenceClient + +client = InferenceClient(model="http://127.0.0.1:8080") +client.text_generation(prompt="Write a code for snake game") +``` + +You can do streaming with `InferenceClient` by passing `stream=True`. 
Streaming will return tokens as they are being generated in the server. To use streaming, you can do as follows: + +```python +for token in client.text_generation("How do you make cheese?", max_new_tokens=12, stream=True): + print(token) +``` + +Another parameter you can use with TGI backend is `details`. You can get more details on generation (tokens, probabilities, etc.) by setting `details` to `True`. When it's specified, TGI will return a `TextGenerationResponse` or `TextGenerationStreamResponse` rather than a string or stream. + +```python +output = client.text_generation(prompt="Meaning of life is", details=True) +print(output) + +# TextGenerationResponse(generated_text=' a complex concept that is not always clear to the individual. It is a concept that is not always', details=Details(finish_reason=, generated_tokens=20, seed=None, prefill=[], tokens=[Token(id=267, text=' a', logprob=-2.0723474, special=False), Token(id=11235, text=' complex', logprob=-3.1272552, special=False), Token(id=17908, text=' concept', logprob=-1.3632495, special=False),..)) +``` + +You can see how to stream below. + +```python +output = client.text_generation(prompt="Meaning of life is", stream=True, details=True) +print(next(iter(output))) + +# TextGenerationStreamResponse(token=Token(id=267, text=' a', logprob=-2.0723474, special=False), generated_text=None, details=None) +``` + +You can check out the details of the function [here](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) + + +## ChatUI + +ChatUI is an open-source interface built for LLM serving. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces. + +To serve both ChatUI and TGI in same environment, simply add your own endpoints to the `MODELS` variable in `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served. + +``` +{ +// rest of the model config here +"endpoints": [{"url": "https://HOST:PORT/generate_stream"}] +} +``` + +![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png) + +## Gradio + +Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first. + +```bash +pip install huggingface-hub gradio +``` + +Assume you are serving your model on port 8080, we will query through [InferenceClient](consuming_tgi#inference-client). 
+ +```python +import gradio as gr +from huggingface_hub import InferenceClient + +client = InferenceClient(model="http://127.0.0.1:8080") + +def inference(message, history): + partial_message = "" + for token in client.text_generation(message, max_new_tokens=20, stream=True): + partial_message += token + yield partial_message + +gr.ChatInterface( + inference, + chatbot=gr.Chatbot(height=300), + textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7), + description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.", + title="Gradio 🤝 TGI", + examples=["Are tomatoes vegetables?"], + retry_btn="Retry", + undo_btn="Undo", + clear_btn="Clear", +).queue().launch() +``` + +The UI looks like this 👇 + +
[Screenshot of the Gradio chat UI]

You can try the demo directly here 👇

[Embedded Gradio demo]
+ + + +You can disable streaming mode using `return` instead of `yield` in your inference function, like below. + +```python +def inference(message, history): + return client.text_generation(message, max_new_tokens=20) +``` + +You can read more about how to customize a `ChatInterface` [here](https://www.gradio.app/guides/creating-a-chatbot-fast). + +## API documentation + +You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. The Swagger UI is also available [here](https://huggingface.github.io/text-generation-inference). diff --git a/docs/source/basic_tutorials/gated_model_access.md b/docs/source/basic_tutorials/gated_model_access.md new file mode 100644 index 00000000..b49c59c9 --- /dev/null +++ b/docs/source/basic_tutorials/gated_model_access.md @@ -0,0 +1,24 @@ +# Serving Private & Gated Models + +If the model you wish to serve is behind gated access or the model repository on Hugging Face Hub is private, and you have access to the model, you can provide your Hugging Face Hub access token. You can generate and copy a read token from [Hugging Face Hub tokens page](https://huggingface.co/settings/tokens) + +If you're using the CLI, set the `HUGGING_FACE_HUB_TOKEN` environment variable. For example: + +``` +export HUGGING_FACE_HUB_TOKEN= +``` + +If you would like to do it through Docker, you can provide your token by specifying `HUGGING_FACE_HUB_TOKEN` as shown below. + +```bash +model=meta-llama/Llama-2-7b-chat-hf +volume=$PWD/data +token= + +docker run --gpus all \ + --shm-size 1g \ + -e HUGGING_FACE_HUB_TOKEN=$token \ + -p 8080:80 \ + -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \ + --model-id $model +``` diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md new file mode 100644 index 00000000..08a03d0d --- /dev/null +++ b/docs/source/basic_tutorials/launcher.md @@ -0,0 +1,432 @@ +# Text-generation-launcher arguments + + + +```shell +Text Generation Launcher + +Usage: text-generation-launcher [OPTIONS] + +Options: +``` +## MODEL_ID +```shell + --model-id + The name of the model to load. Can be a MODEL_ID as listed on like `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`. Or it can be a local directory containing the necessary files as saved by `save_pretrained(...)` methods of transformers + + [env: MODEL_ID=] + [default: bigscience/bloom-560m] + +``` +## REVISION +```shell + --revision + The actual revision of the model if you're referring to a model on the hub. You can use a specific commit id or a branch like `refs/pr/2` + + [env: REVISION=] + +``` +## VALIDATION_WORKERS +```shell + --validation-workers + The number of tokenizer workers used for payload validation and truncation inside the router + + [env: VALIDATION_WORKERS=] + [default: 2] + +``` +## SHARDED +```shell + --sharded + Whether to shard the model across multiple GPUs By default text-generation-inference will use all available GPUs to run the model. Setting it to `false` deactivates `num_shard` + + [env: SHARDED=] + [possible values: true, false] + +``` +## NUM_SHARD +```shell + --num-shard + The number of shards to use if you don't want to use all GPUs on a given machine. You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2` and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... 
--num_shard 2` to launch 2 copies with 2 shard each on a given machine with 4 GPUs for instance + + [env: NUM_SHARD=] + +``` +## QUANTIZE +```shell + --quantize + Whether you want the model to be quantized + + [env: QUANTIZE=] + + Possible values: + - awq: 4 bit quantization. Requires a specific AWQ quantized model: . Should replace GPTQ models wherever possible because of the better latency + - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from + - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: . Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1) + - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: . text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels + - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 + - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 + - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model + - fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above This dtype has native ops should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations + +``` +## SPECULATE +```shell + --speculate + The number of input_ids to speculate on If using a medusa model, the heads will be picked up automatically Other wise, it will use n-gram speculation which is relatively free in terms of compute, but the speedup heavily depends on the task + + [env: SPECULATE=] + +``` +## DTYPE +```shell + --dtype + The dtype to be forced upon the model. This option cannot be used with `--quantize` + + [env: DTYPE=] + [possible values: float16, bfloat16] + +``` +## TRUST_REMOTE_CODE +```shell + --trust-remote-code + Whether you want to execute hub modelling code. Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision + + [env: TRUST_REMOTE_CODE=] + +``` +## MAX_CONCURRENT_REQUESTS +```shell + --max-concurrent-requests + The maximum amount of concurrent requests for this particular deployment. Having a low limit will refuse clients requests instead of having them wait for too long and is usually good to handle backpressure correctly + + [env: MAX_CONCURRENT_REQUESTS=] + [default: 128] + +``` +## MAX_BEST_OF +```shell + --max-best-of + This is the maximum allowed value for clients to set `best_of`. Best of makes `n` generations at the same time, and return the best in terms of overall log probability over the entire generated sequence + + [env: MAX_BEST_OF=] + [default: 2] + +``` +## MAX_STOP_SEQUENCES +```shell + --max-stop-sequences + This is the maximum allowed value for clients to set `stop_sequences`. 
Stop sequences are used to allow the model to stop on more than just the EOS token, and enable more complex "prompting" where users can preprompt the model in a specific way and define their "own" stop token aligned with their prompt + + [env: MAX_STOP_SEQUENCES=] + [default: 4] + +``` +## MAX_TOP_N_TOKENS +```shell + --max-top-n-tokens + This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens` is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking + + [env: MAX_TOP_N_TOKENS=] + [default: 5] + +``` +## MAX_INPUT_TOKENS +```shell + --max-input-tokens + This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_position_embeddings - 1, 4095) + + [env: MAX_INPUT_TOKENS=] + +``` +## MAX_INPUT_LENGTH +```shell + --max-input-length + Legacy version of [`Args::max_input_tokens`] + + [env: MAX_INPUT_LENGTH=] + +``` +## MAX_TOTAL_TOKENS +```shell + --max-total-tokens + This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_position_embeddings, 4096) + + [env: MAX_TOTAL_TOKENS=] + +``` +## WAITING_SERVED_RATIO +```shell + --waiting-served-ratio + This represents the ratio of waiting queries vs running queries where you want to start considering pausing the running queries to include the waiting ones into the same batch. `waiting_served_ratio=1.2` Means when 12 queries are waiting and there's only 10 queries left in the current batch we check if we can fit those 12 waiting queries into the batching strategy, and if yes, then batching happens delaying the 10 running queries by a `prefill` run. + + This setting is only applied if there is room in the batch as defined by `max_batch_total_tokens`. + + [env: WAITING_SERVED_RATIO=] + [default: 0.3] + +``` +## MAX_BATCH_PREFILL_TOKENS +```shell + --max-batch-prefill-tokens + Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent. Default to `max_input_tokens + 50` to give a bit of room + + [env: MAX_BATCH_PREFILL_TOKENS=] + +``` +## MAX_BATCH_TOTAL_TOKENS +```shell + --max-batch-total-tokens + **IMPORTANT** This is one critical control to allow maximum usage of the available hardware. + + This represents the total amount of potential tokens within a batch. When using padding (not recommended) this would be equivalent of `batch_size` * `max_total_tokens`. + + However in the non-padded (flash attention) version this can be much finer. + + For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100` or a single query of `1000` tokens. + + Overall this number should be the largest possible amount that fits the remaining memory (after the model is loaded). 
Since the actual memory overhead depends on other parameters like if you're using quantization, flash attention or the model implementation, text-generation-inference cannot infer this number automatically. + + [env: MAX_BATCH_TOTAL_TOKENS=] + +``` +## MAX_WAITING_TOKENS +```shell + --max-waiting-tokens + This setting defines how many tokens can be passed before forcing the waiting queries to be put on the batch (if the size of the batch allows for it). New queries require 1 `prefill` forward, which is different from `decode` and therefore you need to pause the running batch in order to run `prefill` to create the correct values for the waiting queries to be able to join the batch. + + With a value too small, queries will always "steal" the compute to run `prefill` and running queries will be delayed by a lot. + + With a value too big, waiting queries could wait for a very long time before being allowed a slot in the running batch. If your server is busy that means that requests that could run in ~2s on an empty server could end up running in ~20s because the query had to wait for 18s. + + This number is expressed in number of tokens to make it a bit more "model" agnostic, but what should really matter is the overall latency for end users. + + [env: MAX_WAITING_TOKENS=] + [default: 20] + +``` +## MAX_BATCH_SIZE +```shell + --max-batch-size + Enforce a maximum number of requests per batch Specific flag for hardware targets that do not support unpadded inference + + [env: MAX_BATCH_SIZE=] + +``` +## CUDA_GRAPHS +```shell + --cuda-graphs + Specify the batch sizes to compute cuda graphs for. Use "0" to disable. Default = "1,2,4,8,16,32" + + [env: CUDA_GRAPHS=] + +``` +## HOSTNAME +```shell + --hostname + The IP address to listen on + + [env: HOSTNAME=] + [default: 0.0.0.0] + +``` +## PORT +```shell + -p, --port + The port to listen on + + [env: PORT=] + [default: 3000] + +``` +## SHARD_UDS_PATH +```shell + --shard-uds-path + The name of the socket for gRPC communication between the webserver and the shards + + [env: SHARD_UDS_PATH=] + [default: /tmp/text-generation-server] + +``` +## MASTER_ADDR +```shell + --master-addr + The address the master shard will listen on. (setting used by torch distributed) + + [env: MASTER_ADDR=] + [default: localhost] + +``` +## MASTER_PORT +```shell + --master-port + The address the master port will listen on. (setting used by torch distributed) + + [env: MASTER_PORT=] + [default: 29500] + +``` +## HUGGINGFACE_HUB_CACHE +```shell + --huggingface-hub-cache + The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance + + [env: HUGGINGFACE_HUB_CACHE=] + +``` +## WEIGHTS_CACHE_OVERRIDE +```shell + --weights-cache-override + The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance + + [env: WEIGHTS_CACHE_OVERRIDE=] + +``` +## DISABLE_CUSTOM_KERNELS +```shell + --disable-custom-kernels + For some models (like bloom), text-generation-inference implemented custom cuda kernels to speed up inference. Those kernels were only tested on A100. Use this flag to disable them if you're running on different hardware and encounter issues + + [env: DISABLE_CUSTOM_KERNELS=] + +``` +## CUDA_MEMORY_FRACTION +```shell + --cuda-memory-fraction + Limit the CUDA available memory. 
The allowed value equals the total visible memory multiplied by cuda-memory-fraction + + [env: CUDA_MEMORY_FRACTION=] + [default: 1.0] + +``` +## ROPE_SCALING +```shell + --rope-scaling + Rope scaling will only be used for RoPE models and allow rescaling the position rotary to accomodate for larger prompts. + + Goes together with `rope_factor`. + + `--rope-factor 2.0` gives linear scaling with a factor of 2.0 `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0 `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed basically) + + `--rope-scaling linear --rope-factor` fully describes the scaling you want + + [env: ROPE_SCALING=] + [possible values: linear, dynamic] + +``` +## ROPE_FACTOR +```shell + --rope-factor + Rope scaling will only be used for RoPE models See `rope_scaling` + + [env: ROPE_FACTOR=] + +``` +## JSON_OUTPUT +```shell + --json-output + Outputs the logs in JSON format (useful for telemetry) + + [env: JSON_OUTPUT=] + +``` +## OTLP_ENDPOINT +```shell + --otlp-endpoint + [env: OTLP_ENDPOINT=] + +``` +## CORS_ALLOW_ORIGIN +```shell + --cors-allow-origin + [env: CORS_ALLOW_ORIGIN=] + +``` +## WATERMARK_GAMMA +```shell + --watermark-gamma + [env: WATERMARK_GAMMA=] + +``` +## WATERMARK_DELTA +```shell + --watermark-delta + [env: WATERMARK_DELTA=] + +``` +## NGROK +```shell + --ngrok + Enable ngrok tunneling + + [env: NGROK=] + +``` +## NGROK_AUTHTOKEN +```shell + --ngrok-authtoken + ngrok authentication token + + [env: NGROK_AUTHTOKEN=] + +``` +## NGROK_EDGE +```shell + --ngrok-edge + ngrok edge + + [env: NGROK_EDGE=] + +``` +## TOKENIZER_CONFIG_PATH +```shell + --tokenizer-config-path + The path to the tokenizer config file. This path is used to load the tokenizer configuration which may include a `chat_template`. If not provided, the default config will be used from the model hub + + [env: TOKENIZER_CONFIG_PATH=] + +``` +## DISABLE_GRAMMAR_SUPPORT +```shell + --disable-grammar-support + Disable outlines grammar constrained generation. This is a feature that allows you to generate text that follows a specific grammar + + [env: DISABLE_GRAMMAR_SUPPORT=] + +``` +## ENV +```shell + -e, --env + Display a lot of information about your runtime environment + +``` +## MAX_CLIENT_BATCH_SIZE +```shell + --max-client-batch-size + Control the maximum number of inputs that a client can send in a single request + + [env: MAX_CLIENT_BATCH_SIZE=] + [default: 4] + +``` +## LORA_IDS +```shell + --lora-ids + Specify LoRA adapters + + [env: LORA_IDS=] + [default: empty] + +``` +## HELP +```shell + -h, --help + Print help (see a summary with '-h') + +``` +## VERSION +```shell + -V, --version + Print version + +``` diff --git a/docs/source/basic_tutorials/monitoring.md b/docs/source/basic_tutorials/monitoring.md new file mode 100644 index 00000000..509b0aff --- /dev/null +++ b/docs/source/basic_tutorials/monitoring.md @@ -0,0 +1,75 @@ +# Monitoring TGI server with Prometheus and Grafana dashboard + +TGI server deployment can easily be monitored through a Grafana dashboard, consuming a Prometheus data collection. Example of inspectable metrics are statistics on the effective batch sizes used by TGI, prefill/decode latencies, number of generated tokens, etc. + +In this tutorial, we look at how to set up a local Grafana dashboard to monitor TGI usage. 
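Before setting anything up, you can sanity-check that metrics are being exposed by querying the `/metrics` route directly. This is a minimal check, assuming TGI is reachable on `0.0.0.0:80` as in the Docker examples used throughout these docs; adjust the host and port to your deployment.

```bash
# Fetch the raw Prometheus-format metrics exposed by the TGI router.
# Host and port are assumptions based on the Docker examples in these docs;
# change them to match how you launched TGI.
curl 0.0.0.0:80/metrics
```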
+ +![Grafana dashboard for TGI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/grafana.png) + +## Setup on the server machine + +First, on your server machine, TGI needs to be launched as usual. TGI exposes [multiple](https://github.com/huggingface/text-generation-inference/discussions/1127#discussioncomment-7240527) metrics that can be collected by Prometheus monitoring server. + +In the rest of this tutorial, we assume that TGI was launched through Docker with `--network host`. + +On the server where TGI is hosted, a Prometheus server needs to be installed and launched. To do so, please follow [Prometheus installation instructions](https://prometheus.io/download/#prometheus). For example, at the time of writing on a Linux machine: + +``` +wget https://github.com/prometheus/prometheus/releases/download/v2.52.0/prometheus-2.52.0.linux-amd64.tar.gz +tar -xvzf prometheus-2.52.0.linux-amd64.tar.gz +cd prometheus +``` + +Prometheus needs to be configured to listen on TGI's port. To do so, in Prometheus configuration file `prometheus.yml`, one needs to edit the lines: +``` + static_configs: + - targets: ["0.0.0.0:80"] +``` +to use the correct IP address and port. + +We suggest to try `curl 0.0.0.0:80/generate -X POST -d '{"inputs":"hey chatbot, how are","parameters":{"max_new_tokens":15}}' -H 'Content-Type: application/json'` on the server side to make sure to configure the correct IP and port. + +Once Prometheus is configured, Prometheus server can be launched on the same machine where TGI is launched: +``` +./prometheus --config.file="prometheus.yml" +``` + +In this guide, Prometheus monitoring data will be consumed on a local computer. Hence, we need to forward Prometheus port (by default 9090) to the local computer. To do so, we can for example: +* Use ssh [local port forwarding](https://www.ssh.com/academy/ssh/tunneling-example) +* Use ngrok port tunneling + +For simplicity, we will use [Ngrok](https://ngrok.com/docs/) in this guide to tunnel Prometheus port from the TGI server to the outside word. + +For that, you should follow the steps at https://dashboard.ngrok.com/get-started/setup/linux, and once Ngrok is installed, use: +```bash +ngrok http http://0.0.0.0:9090 +``` + +As a sanity check, one can make sure that Prometheus server can be accessed at the URL given by Ngrok (in the style of https://d661-4-223-164-145.ngrok-free.app) from a local machine. + +## Setup on the monitoring machine + +Monitoring is typically done on an other machine than the server one. We use a Grafana dashboard to monitor TGI's server usage. + +Two options are available: +* Use Grafana Cloud for an hosted dashboard solution (https://grafana.com/products/cloud/). +* Self-host a grafana dashboard. + +In this tutorial, for simplicity, we will self host the dashbard. We recommend installing Grafana Open-source edition following [the official install instructions](https://grafana.com/grafana/download?platform=linux&edition=oss), using the available Linux binaries. For example: + +```bash +wget https://dl.grafana.com/oss/release/grafana-11.0.0.linux-amd64.tar.gz +tar -zxvf grafana-11.0.0.linux-amd64.tar.gz +cd grafana-11.0.0 +./bin/grafana-server +``` + +Once the Grafana server is launched, the Grafana interface is available at http://localhost:3000. One needs to log in with the `admin` username and `admin` password. + +Once logged in, the Prometheus data source for Grafana needs to be configured, in the option `Add your first data source`. 
There, a Prometheus data source needs to be added with the Ngrok address we got earlier, that exposes Prometheus port (example: https://d661-4-223-164-145.ngrok-free.app). + +Once Prometheus data source is configured, we can finally create our dashboard! From home, go to `Create your first dashboard` and then `Import dashboard`. There, we will use the recommended dashboard template [tgi_grafana.json](https://github.com/huggingface/text-generation-inference/blob/main/assets/tgi_grafana.json) for a dashboard ready to be used, but you may configure your own dashboard as you like. + +Community contributed dashboard templates are also available, for example [here](https://grafana.com/grafana/dashboards/19831-text-generation-inference-dashboard/) or [here](https://grafana.com/grafana/dashboards/20246-text-generation-inference/). + +Load your dashboard configuration, and your TGI dashboard should be ready to go! diff --git a/docs/source/basic_tutorials/non_core_models.md b/docs/source/basic_tutorials/non_core_models.md new file mode 100644 index 00000000..2badaff0 --- /dev/null +++ b/docs/source/basic_tutorials/non_core_models.md @@ -0,0 +1,24 @@ +# Non-core Model Serving + +TGI supports various LLM architectures (see full list [here](../supported_models)). If you wish to serve a model that is not one of the supported models, TGI will fallback to the `transformers` implementation of that model. This means you will be unable to use some of the features introduced by TGI, such as tensor-parallel sharding or flash attention. However, you can still get many benefits of TGI, such as continuous batching or streaming outputs. + +You can serve these models using the same Docker command-line invocation as with fully supported models 👇 + +```bash +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id gpt2 +``` + +If the model you wish to serve is a custom transformers model, and its weights and implementation are available in the Hub, you can still serve the model by passing the `--trust-remote-code` flag to the `docker run` command like below 👇 + +```bash +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id --trust-remote-code +``` + +Finally, if the model is not on Hugging Face Hub but on your local, you can pass the path to the folder that contains your model like below 👇 + +```bash +# Make sure your model is in the $volume directory +docker run --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id /data/ +``` + +You can refer to [transformers docs on custom models](https://huggingface.co/docs/transformers/main/en/custom_models) for more information. diff --git a/docs/source/basic_tutorials/preparing_model.md b/docs/source/basic_tutorials/preparing_model.md new file mode 100644 index 00000000..71ca5598 --- /dev/null +++ b/docs/source/basic_tutorials/preparing_model.md @@ -0,0 +1,22 @@ +# Preparing the Model + +Text Generation Inference improves the model in several aspects. + +## Quantization + +TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. 
When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq) when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization) + + +## RoPE Scaling + +RoPE scaling can be used to increase the sequence length of the model during the inference time without necessarily fine-tuning it. To enable RoPE scaling, simply pass `--rope-scaling`, `--max-input-length` and `--rope-factors` flags when running through CLI. `--rope-scaling` can take the values `linear` or `dynamic`. If your model is not fine-tuned to a longer sequence length, use `dynamic`. `--rope-factor` is the ratio between the intended max sequence length and the model's original max sequence length. Make sure to pass `--max-input-length` to provide maximum input length for extension. + + + +We recommend using `dynamic` RoPE scaling. + + + +## Safetensors + +[Safetensors](https://github.com/huggingface/safetensors) is a fast and safe persistence format for deep learning models, and is required for tensor parallelism. TGI supports `safetensors` model loading under the hood. By default, given a repository with `safetensors` and `pytorch` weights, TGI will always load `safetensors`. If there's no `pytorch` weights, TGI will convert the weights to `safetensors` format. diff --git a/docs/source/basic_tutorials/safety.md b/docs/source/basic_tutorials/safety.md new file mode 100644 index 00000000..0b865db4 --- /dev/null +++ b/docs/source/basic_tutorials/safety.md @@ -0,0 +1,31 @@ +# Model safety. + +[Pytorch uses pickle](https://pytorch.org/docs/master/generated/torch.load.html) by default meaning that for quite a long while +*Every* model using that format is potentially executing unintended code while purely loading the model. + +There is a big red warning on Python's page for pickle [link](https://docs.python.org/3/library/pickle.html) but for quite a while +this was ignored by the community. Now that AI/ML is getting used much more ubiquitously we need to switch away from this format. + +HuggingFace is leading the effort here by creating a new format which contains pure data ([safetensors](https://github.com/huggingface/safetensors)) +and moving slowly but surely all the libs to make use of it by default. +The move is intentionnally slow in order to make breaking changes as little impact as possible on users throughout. + + +# TGI 2.0 + +Since the release of TGI 2.0, we take the opportunity of this major version increase to break backward compatibility for these pytorch +models (since they are a huge security risk for anyone deploying them). + + +From now on, TGI will not convert automatically pickle files without having `--trust-remote-code` flag or `TRUST_REMOTE_CODE=true` in the environment variables. +This flag is already used for community defined inference code, and is therefore quite representative of the level of confidence you are giving the model providers. + + +If you want to use a model that uses pickle, but you still do not want to trust the authors entirely we recommend making a convertion on our space made for that. + +https://huggingface.co/spaces/safetensors/convert + +This space will create a PR on the original model, which you are use directly regardless of merge status from the original authors. Just use +``` +docker run .... 
--revision refs/pr/#ID # Or use REVISION=refs/pr/#ID in the environment +``` diff --git a/docs/source/basic_tutorials/train_medusa.md b/docs/source/basic_tutorials/train_medusa.md new file mode 100644 index 00000000..ba2e43b7 --- /dev/null +++ b/docs/source/basic_tutorials/train_medusa.md @@ -0,0 +1,208 @@ +# Train Medusa + +This tutorial will show you how to train a Medusa model on a dataset of your choice. Please check out the [speculation documentation](../conceptual/speculation) for more information on how Medusa works and speculation in general. + +## What are the benefits of training a Medusa model? + +Training Medusa heads can greatly improve the speed of generation. Medusa adds extra "heads" to LLMs to predict multiple future tokens simultaneously. When augmenting a model with Medusa, the original model stays untouched, and only the new heads are fine-tuned during training. + +One of the most important things is to have a good dataset (with similar data to what will be used in production) because Medusa has a much higher hit-rate when the generation is in-domain. + +If you train Medusa on a dataset that is very different from the one you will use in production then the model will not be able to predict the future tokens accurately and consequently the speedup will be minimal or non-existent. + +## Self-distillation (Generating data for training) + +There are many methods for preparing data for training, but one of the easiest and most effective ways is to "self-distill" the data. This means that you can use the same model to generate the data that you will use to train the model. + +Essentially, you prompt the model with a similar input to what you will use in production and the model will generate the output. + +We'll use this output to help train the medusa heads to predict the `n+1`, `n+2`, `n+3`, etc tokens in the sequence. + +## Training + +The original implementation of Medusa is available at [https://github.com/FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa) and we'll follow a very similar process to train the model as described on the original repository. + +### Getting Started + +There are two methods for training the model: + +- `torchrun` that is a wrapper around `torch.distributed.launch` +- a forked version of `axlotl` that supports Medusa + +In this tutorial we'll use `torchrun` to train the model as it is the most straightforward way to train the model but similar steps can be followed to train the model using `axlotl` if you prefer. + +### Training with `torchrun` + +```bash +mkdir medusa-training +cd medusa-training + +pyenv install 3.10 +pyenv local 3.10 + +uv venv -p 3.10 +source .venv/bin/activate +``` + +Now lets clone the original `Medusa` repository and install the library. + +```bash +git clone https://github.com/FasterDecoding/Medusa.git +cd Medusa +pip install -e . +``` + +Next we'll need some data to train on, we can use the `ShareGPT_Vicuna_unfiltered` dataset that is available on the Hugging Face Hub. + +```bash +apt install git-lfs +git lfs install +git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered +``` + +Currently our directory structure looks like this: + +```bash +. 
+├── assets +├── CITATION.cff +├── create_data.py +├── data_generation +├── deepspeed.json +├── last_run_prepared +├── LICENSE +├── llm_judge +├── medusa +├── medusa_llm.egg-info +├── mistral.json +├── notebooks +├── pyproject.toml +├── README.md +├── ROADMAP.md +├── scripts +├── ShareGPT_Vicuna_unfiltered +│   ├── README.md +│   ├── ShareGPT_2023.05.04v0_Wasteland_Edition.json +│   └── ShareGPT_V4.3_unfiltered_cleaned_split.json +├── simple_gradio_interface.py +├── tiny-llama.json +└── vicuna_7b_qlora_stage1 +``` + +## Start Training + +Now the lets generate the data and start training the model. This process will take a while since we are generating data from the model. + +First make sure you have an instance of TGI running with the model you want to use for self-distillation. + +```bash +model=HuggingFaceH4/zephyr-7b-beta +volume=/home/ubuntu/.cache/huggingface/hub/ + +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model +``` + +Now we can generate the data using the `create_data.py` script. + +```bash +python create_data.py \ + --input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \ + --output-filename zephyr_self_distill.json +``` + +At this point our terminal should look like this: + +
+ +
+
+> Note: In the screenshot above we are only using the first 500 examples from the dataset to speed up the process; you should use a much larger dataset for training.
+
+Now we can finally get to the fun part and start training the model!
+
+Using `torchrun`, we can easily launch the `medusa` training script with the `zephyr_self_distill.json` data we just generated.
+
+> NOTE: If you just self-distilled, you may still have the model running; make sure to stop it before starting the training so that all of the resources can be used for training.
+
+```bash
+WANDB_MODE=offline torchrun --nproc_per_node=4 medusa/train/train_legacy.py \
+    --model_name_or_path HuggingFaceH4/zephyr-7b-beta \
+    --data_path zephyr_self_distill.json \
+    --bf16 True \
+    --output_dir zephyr_out \
+    --num_train_epochs 5 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps 4 \
+    --evaluation_strategy "no" \
+    --save_strategy "no" \
+    --learning_rate 1e-3 \
+    --weight_decay 0.0 \
+    --warmup_ratio 0.1 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --tf32 True \
+    --model_max_length 2048 \
+    --lazy_preprocess True \
+    --medusa_num_heads 3 \
+    --medusa_num_layers 1 \
+    --deepspeed deepspeed.json
+```
+
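+The last flag points at `deepspeed.json`, the DeepSpeed configuration that ships with the Medusa repository (visible in the directory listing above). Purely as an illustration of what such a file typically configures, and assuming the training script relies on the Hugging Face `Trainer` DeepSpeed integration (the launcher flags above suggest it does), a minimal ZeRO-2 style config looks roughly like the annotated Python dict below; this is a sketch, not the repo's actual JSON:
+
+```python
+# Illustrative sketch only -- NOT the exact deepspeed.json shipped with the Medusa repo.
+# "auto" values are resolved from the Trainer command-line arguments at launch time.
+deepspeed_config = {
+    "bf16": {"enabled": True},                 # train in bfloat16, matching --bf16 True
+    "zero_optimization": {"stage": 2},         # shard optimizer state and gradients across GPUs
+    "train_micro_batch_size_per_gpu": "auto",  # taken from --per_device_train_batch_size
+    "gradient_accumulation_steps": "auto",     # taken from --gradient_accumulation_steps
+    "gradient_clipping": "auto",
+}
+```
+
+In practice, keep the `deepspeed.json` from the repository unless you know you need to change it.
+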
+ +
+ +If successful, you should see the similar output to the one below: + +```bash +wandb: Run history: +wandb: train/epoch ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███ +wandb: train/global_step ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███ +wandb: train/learning_rate ▅███▇▇▆▅▅▄▃▂▂▁▁▁ +wandb: train/loss ██▆▄▄▃▃▂▂▃▁▁▂▁▁▁ +wandb: train/medusa0_loss ▆▆▇▆▆▅▄▅▃▃▃▃▂▂▂▂▂▃▂▂▂▁▁▁▂▁▁▁▁▁█▁▁▁▂▁▁▁▁▁ +wandb: train/medusa0_top1 ▁▁▁▁▁▁▁▁▃▂▃▃▄▄▄▃▄▃▄▄▅▅▆▅▆▆▇▅▇▇▄▇█▇▅▇█▆▇▇ +wandb: train/medusa1_loss ▇▇█▇▇▆▅▅▃▄▃▃▃▃▃▃▃▃▃▃▂▁▂▂▂▁▁▂▁▁▇▁▁▁▂▁▁▁▁▁ +wandb: train/medusa1_top1 ▁▁▁▁▁▁▁▁▃▂▃▃▃▄▄▃▃▂▃▃▅▅▆▄█▆▇▅▇▇▅█▇▇▅▇█▆▆▇ +wandb: train/medusa2_loss ▃▃▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁█▁▁▁▂▁▁▁▁▁ +wandb: train/medusa2_top1 ▁▁▁▂▁▁▁▁▂▂▃▃▃▄▄▃▃▂▃▃▅▆▅▄█▆▆▅▆▆▄█▇▇▄▇█▆▆▇ +wandb: train/total_flos ▁ +wandb: train/train_loss ▁ +wandb: train/train_runtime ▁ +wandb: train/train_samples_per_second ▁ +wandb: train/train_steps_per_second ▁ +wandb: +wandb: Run summary: +wandb: train/epoch 2.0 +wandb: train/global_step 16 +wandb: train/learning_rate 0.0 +wandb: train/loss 14.8906 +wandb: train/medusa0_loss 4.25 +wandb: train/medusa0_top1 0.28809 +wandb: train/medusa1_loss 4.8125 +wandb: train/medusa1_top1 0.22727 +wandb: train/medusa2_loss 5.5 +wandb: train/medusa2_top1 0.17293 +wandb: train/total_flos 0.0 +wandb: train/train_loss 23.98242 +wandb: train/train_runtime 396.9266 +wandb: train/train_samples_per_second 2.519 +wandb: train/train_steps_per_second 0.04 +``` + +Last but most importantly, don't forget to push this model to the Hugging Face Hub so you can use it in your projects. + +```bash +python -m medusa.hf_utils \ + --folder zephyr_out_medusa_mlp_zephyr-7b-beta_medusa_3_lr_0.001_layers_1 \ + --repo drbh/zephyr_medusa_demo +``` + +Woo, we've successfully trained a Medusa model and pushed it to the Hugging Face Hub! 🎉 diff --git a/docs/source/basic_tutorials/using_cli.md b/docs/source/basic_tutorials/using_cli.md new file mode 100644 index 00000000..64554069 --- /dev/null +++ b/docs/source/basic_tutorials/using_cli.md @@ -0,0 +1,35 @@ +# Using TGI CLI + +You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](../installation#install-cli). + +`text-generation-server` lets you download the model with `download-weights` command like below 👇 + +```bash +text-generation-server download-weights MODEL_HUB_ID +``` + +You can also use it to quantize models like below 👇 + +```bash +text-generation-server quantize MODEL_HUB_ID OUTPUT_DIR +``` + +You can use `text-generation-launcher` to serve models. + +```bash +text-generation-launcher --model-id MODEL_HUB_ID --port 8080 +``` + +There are many options and parameters you can pass to `text-generation-launcher`. The documentation for CLI is kept minimal and intended to rely on self-generating documentation, which can be found by running + +```bash +text-generation-launcher --help +``` + +You can also find it hosted in this [Swagger UI](https://huggingface.github.io/text-generation-inference/). + +Same documentation can be found for `text-generation-server`. 
+
+```bash
+text-generation-server --help
+```
diff --git a/docs/source/basic_tutorials/using_guidance.md b/docs/source/basic_tutorials/using_guidance.md
new file mode 100644
index 00000000..d0008fdb
--- /dev/null
+++ b/docs/source/basic_tutorials/using_guidance.md
@@ -0,0 +1,359 @@
+# Guidance
+
+Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.
+
+These features are available starting from version `1.4.3`. They are accessible via the [`huggingface_hub`](https://pypi.org/project/huggingface-hub/) library. The tool support is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
+
+_note: guidance is supported as grammar in the `/generate` endpoint and as tools in the `/chat/completions` endpoint._
+
+## How it works
+
+TGI leverages the [outlines](https://github.com/outlines-dev/outlines) library to efficiently parse and compile the grammatical structures and tools specified by users. This integration transforms the defined grammars into an intermediate representation that acts as a framework to guide and constrain content generation, ensuring that outputs adhere to the specified grammatical rules.
+
+If you are interested in the technical details on how outlines is used in TGI, you can check out the [conceptual guidance documentation](../conceptual/guidance).
+
+## Table of Contents 📚
+
+### Grammar and Constraints
+
+- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
+- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
+- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.
+- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.
+
+### Tools and Functions
+
+- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions.
+- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions.
+- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
+
+## Grammar and Constraints 🛣️
+
+### The Grammar Parameter
+
+In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the LLM.
+
+Using curl, you can make a request to TGI's `/generate` endpoint with the grammar parameter. This is the most primitive way to interact with the API; using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability.
+ +```json +curl localhost:3000/generate \ + -X POST \ + -H 'Content-Type: application/json' \ + -d '{ + "inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park", + "parameters": { + "repetition_penalty": 1.3, + "grammar": { + "type": "json", + "value": { + "properties": { + "location": { + "type": "string" + }, + "activity": { + "type": "string" + }, + "animals_seen": { + "type": "integer", + "minimum": 1, + "maximum": 5 + }, + "animals": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": ["location", "activity", "animals_seen", "animals"] + } + } + } +}' +// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"} + +``` + +### Hugging Face Hub Python Library + +The Hugging Face Hub Python library provides a client that makes it easy to interact with the Messages API. Here's an example of how to use the client to send a request with a grammar parameter. + +```python +from huggingface_hub import InferenceClient + +client = InferenceClient("http://localhost:3000") + +schema = { + "properties": { + "location": {"title": "Location", "type": "string"}, + "activity": {"title": "Activity", "type": "string"}, + "animals_seen": { + "maximum": 5, + "minimum": 1, + "title": "Animals Seen", + "type": "integer", + }, + "animals": {"items": {"type": "string"}, "title": "Animals", "type": "array"}, + }, + "required": ["location", "activity", "animals_seen", "animals"], + "title": "Animals", + "type": "object", +} + +user_input = "I saw a puppy a cat and a raccoon during my bike ride in the park" +resp = client.text_generation( + f"convert to JSON: 'f{user_input}'. please use the following schema: {schema}", + max_new_tokens=100, + seed=42, + grammar={"type": "json", "value": schema}, +) + +print(resp) +# { "activity": "bike ride", "animals": ["puppy", "cat", "raccoon"], "animals_seen": 3, "location": "park" } + +``` + +A grammar can be defined using Pydantic models, JSON schemas, or regular expressions. The LLM will then generate a response that conforms to the specified grammar. + +> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster. + +### Constrain with Pydantic + +Using Pydantic models we can define a similar grammar as the previous example in a shorter and more readable way. + +```python +from huggingface_hub import InferenceClient +from pydantic import BaseModel, conint +from typing import List + + +class Animals(BaseModel): + location: str + activity: str + animals_seen: conint(ge=1, le=5) # Constrained integer type + animals: List[str] + + +client = InferenceClient("http://localhost:3000") + +user_input = "I saw a puppy a cat and a raccoon during my bike ride in the park" +resp = client.text_generation( + f"convert to JSON: 'f{user_input}'. 
please use the following schema: {Animals.schema()}", + max_new_tokens=100, + seed=42, + grammar={"type": "json", "value": Animals.schema()}, +) + +print(resp) +# { "activity": "bike ride", "animals": ["puppy", "cat", "raccoon"], "animals_seen": 3, "location": "park" } + + +``` + +defining a grammar as regular expressions + +```python +from huggingface_hub import InferenceClient + +client = InferenceClient("http://localhost:3000") + +regexp = "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)" + +resp = client.text_generation( + f"Whats Googles DNS? Please use the following regex: {regexp}", + seed=42, + grammar={ + "type": "regex", + "value": regexp, + }, +) + + +print(resp) +# 7.1.1.1 + +``` + +## Tools and Functions 🛠️ + +### The Tools Parameter + +In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API. + +Tools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the LLM's capabilities. Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API. + +Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API. + +```json +curl localhost:3000/v1/chat/completions \ + -X POST \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "tgi", + "messages": [ + { + "role": "user", + "content": "What is the weather like in New York?" + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location." + } + }, + "required": ["location", "format"] + } + } + } + ], + "tool_choice": "get_current_weather" +}' +// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.3-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}} +``` + +### Chat Completion with Tools + +Grammars are supported in the `/generate` endpoint, while tools are supported in the `/chat/completions` endpoint. Here's an example of how to use the client to send a request with a tool parameter. + +```python +from huggingface_hub import InferenceClient + +client = InferenceClient("http://localhost:3000") + +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. 
Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_n_day_weather_forecast", + "description": "Get an N-day weather forecast", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + "num_days": { + "type": "integer", + "description": "The number of days to forecast", + }, + }, + "required": ["location", "format", "num_days"], + }, + }, + }, +] + +chat = client.chat_completion( + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + tools=tools, + seed=42, + max_tokens=100, +) + +print(chat.choices[0].message.tool_calls) +# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'format': 'fahrenheit', 'location': 'Brooklyn, New York', 'num_days': 7}, name='get_n_day_weather_forecast', description=None), id=0, type='function')] + +``` + +### OpenAI integration + +TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions. + +However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary. + +```python +from openai import OpenAI + +# Initialize the client, pointing it to one of the available models +client = OpenAI( + base_url="http://localhost:3000/v1", + api_key="_", +) + +# NOTE: tools defined above and removed for brevity + +chat_completion = client.chat.completions.create( + model="tgi", + messages=[ + { + "role": "system", + "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.", + }, + { + "role": "user", + "content": "What's the weather like the next 3 days in San Francisco, CA?", + }, + ], + tools=tools, + tool_choice="auto", # tool selected by model + max_tokens=500, +) + + +called = chat_completion.choices[0].message.tool_calls +print(called) +# { +# "id": 0, +# "type": "function", +# "function": { +# "description": None, +# "name": "tools", +# "parameters": { +# "format": "celsius", +# "location": "San Francisco, CA", +# "num_days": 3, +# }, +# }, +# } +``` diff --git a/docs/source/basic_tutorials/visual_language_models.md b/docs/source/basic_tutorials/visual_language_models.md new file mode 100644 index 00000000..3770db0b --- /dev/null +++ b/docs/source/basic_tutorials/visual_language_models.md @@ -0,0 +1,230 @@ +# Vision Language Model Inference in TGI + +Visual Language Model (VLM) are models that consume both image and text inputs to generate text. + +VLM's are trained on a combination of image and text data and can handle a wide range of tasks, such as image captioning, visual question answering, and visual dialog. + +> What distinguishes VLMs from other text and image models is their ability to handle long context and generate text that is coherent and relevant to the image even after multiple turns or in some cases, multiple images. 
+ +Below are couple of common use cases for vision language models: + +- **Image Captioning**: Given an image, generate a caption that describes the image. +- **Visual Question Answering (VQA)**: Given an image and a question about the image, generate an answer to the question. +- **Mulimodal Dialog**: Generate response to multiple turns of images and conversations. +- **Image Information Retrieval**: Given an image, retrieve information from the image. + +## How to Use a Vision Language Model? + +### Hugging Face Hub Python Library + +To infer with vision language models through Python, you can use the [`huggingface_hub`](https://pypi.org/project/huggingface-hub/) library. The `InferenceClient` class provides a simple way to interact with the [Inference API](https://huggingface.co/docs/api-inference/index). Images can be passed as URLs or base64-encoded strings. The `InferenceClient` will automatically detect the image format. + +```python +from huggingface_hub import InferenceClient + +client = InferenceClient("http://127.0.0.1:3000") +image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" +prompt = f"![]({image})What is this a picture of?\n\n" +for token in client.text_generation(prompt, max_new_tokens=16, stream=True): + print(token) + +# This is a picture of an anthropomorphic rabbit in a space suit. +``` + +```python +from huggingface_hub import InferenceClient +import base64 +import requests +import io + +client = InferenceClient("http://127.0.0.1:3000") + +# read image from local file +image_path = "rabbit.png" +with open(image_path, "rb") as f: + image = base64.b64encode(f.read()).decode("utf-8") + +image = f"data:image/png;base64,{image}" +prompt = f"![]({image})What is this a picture of?\n\n" + +for token in client.text_generation(prompt, max_new_tokens=10, stream=True): + print(token) + +# This is a picture of an anthropomorphic rabbit in a space suit. +``` + +or via the `chat_completion` endpoint: + +```python +from huggingface_hub import InferenceClient + +client = InferenceClient("http://127.0.0.1:3000") + +chat = client.chat_completion( + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Whats in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, + }, + ], + }, + ], + seed=42, + max_tokens=100, +) + +print(chat) +# ChatCompletionOutput(choices=[ChatCompletionOutputComplete(finish_reason='length', index=0, message=ChatCompletionOutputMessage(role='assistant', content=" The image you've provided features an anthropomorphic rabbit in spacesuit attire. This rabbit is depicted with human-like posture and movement, standing on a rocky terrain with a vast, reddish-brown landscape in the background. 
The spacesuit is detailed with mission patches, circuitry, and a helmet that covers the rabbit's face and ear, with an illuminated red light on the chest area.\n\nThe artwork style is that of a", name=None, tool_calls=None), logprobs=None)], created=1714589614, id='', model='llava-hf/llava-v1.6-mistral-7b-hf', object='text_completion', system_fingerprint='2.0.2-native', usage=ChatCompletionOutputUsage(completion_tokens=100, prompt_tokens=2943, total_tokens=3043)) + +``` + +or with OpenAi's library: + +```python +from openai import OpenAI + +# init the client but point it to TGI +client = OpenAI(base_url="http://localhost:3000/v1", api_key="-") + +chat_completion = client.chat.completions.create( + model="tgi", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Whats in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, + }, + ], + }, + ], + stream=False, +) + +print(chat_completion) +# ChatCompletion(id='', choices=[Choice(finish_reason='eos_token', index=0, logprobs=None, message=ChatCompletionMessage(content=' The image depicts an anthropomorphic rabbit dressed in a space suit with gear that resembles NASA attire. The setting appears to be a solar eclipse with dramatic mountain peaks and a partial celestial body in the sky. The artwork is detailed and vivid, with a warm color palette and a sense of an adventurous bunny exploring or preparing for a journey beyond Earth. ', role='assistant', function_call=None, tool_calls=None))], created=1714589732, model='llava-hf/llava-v1.6-mistral-7b-hf', object='text_completion', system_fingerprint='2.0.2-native', usage=CompletionUsage(completion_tokens=84, prompt_tokens=2943, total_tokens=3027)) +``` + +### Inference Through Sending `cURL` Requests + +To use the `generate_stream` endpoint with curl, you can add the `-N` flag. This flag disables curl default buffering and shows data as it arrives from the server. + +```bash +curl -N 127.0.0.1:3000/generate_stream \ + -X POST \ + -d '{"inputs":"![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)What is this a picture of?\n\n","parameters":{"max_new_tokens":16, "seed": 42}}' \ + -H 'Content-Type: application/json' + +# ... +# data:{"index":16,"token":{"id":28723,"text":".","logprob":-0.6196289,"special":false},"generated_text":"This is a picture of an anthropomorphic rabbit in a space suit.","details":null} +``` + +### Inference Through JavaScript + +First, we need to install the `@huggingface/inference` library. + +```bash +npm install @huggingface/inference +``` + +If you're using the free Inference API, you can use [Huggingface.js](https://huggingface.co/docs/huggingface.js/inference/README)'s `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint` class to easily interact with the Inference API. + +We can create a `HfInferenceEndpoint` providing our endpoint URL and We can create a `HfInferenceEndpoint` providing our endpoint URL and [Hugging Face access token](https://huggingface.co/settings/tokens). 
+ +```js +import { HfInferenceEndpoint } from "@huggingface/inference"; + +const hf = new HfInferenceEndpoint("http://127.0.0.1:3000", "HF_TOKEN"); + +const prompt = + "![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)What is this a picture of?\n\n"; + +const stream = hf.textGenerationStream({ + inputs: prompt, + parameters: { max_new_tokens: 16, seed: 42 }, +}); +for await (const r of stream) { + // yield the generated token + process.stdout.write(r.token.text); +} + +// This is a picture of an anthropomorphic rabbit in a space suit. +``` + +## Combining Vision Language Models with Other Features + +VLMs in TGI have several advantages, for example these models can be used in tandem with other features for more complex tasks. For example, you can use VLMs with [Guided Generation](/docs/conceptual/guided-generation) to generate specific JSON data from an image. + +
+ +
+ +For example we can extract information from the rabbit image and generate a JSON object with the location, activity, number of animals seen, and the animals seen. That would look like this: + +```json +{ + "activity": "Standing", + "animals": ["Rabbit"], + "animals_seen": 1, + "location": "Rocky surface with mountains in the background and a red light on the rabbit's chest" +} +``` + +All we need to do is provide a JSON schema to the VLM model and it will generate the JSON object for us. + +```bash +curl localhost:3000/generate \ + -X POST \ + -H 'Content-Type: application/json' \ + -d '{ + "inputs":"![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)What is this a picture of?\n\n", + "parameters": { + "max_new_tokens": 100, + "seed": 42, + "grammar": { + "type": "json", + "value": { + "properties": { + "location": { + "type": "string" + }, + "activity": { + "type": "string" + }, + "animals_seen": { + "type": "integer", + "minimum": 1, + "maximum": 5 + }, + "animals": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": ["location", "activity", "animals_seen", "animals"] + } + } + } +}' + +# { +# "generated_text": "{ \"activity\": \"Standing\", \"animals\": [ \"Rabbit\" ], \"animals_seen\": 1, \"location\": \"Rocky surface with mountains in the background and a red light on the rabbit's chest\" }" +# } +``` + +Want to learn more about how Vision Language Models work? Check out the [awesome blog post on the topic](https://huggingface.co/blog/vlms). diff --git a/docs/source/conceptual/flash_attention.md b/docs/source/conceptual/flash_attention.md new file mode 100644 index 00000000..6b13cd13 --- /dev/null +++ b/docs/source/conceptual/flash_attention.md @@ -0,0 +1,11 @@ +# Flash Attention + +Scaling the transformer architecture is heavily bottlenecked by the self-attention mechanism, which has quadratic time and memory complexity. Recent developments in accelerator hardware mainly focus on enhancing compute capacities and not memory and transferring data between hardware. This results in attention operation having a memory bottleneck. **Flash Attention** is an attention algorithm used to reduce this problem and scale transformer-based models more efficiently, enabling faster training and inference. + +Standard attention mechanism uses High Bandwidth Memory (HBM) to store, read and write keys, queries and values. HBM is large in memory, but slow in processing, meanwhile SRAM is smaller in memory, but faster in operations. In the standard attention implementation, the cost of loading and writing keys, queries, and values from HBM is high. It loads keys, queries, and values from HBM to GPU on-chip SRAM, performs a single step of the attention mechanism, writes it back to HBM, and repeats this for every single attention step. Instead, Flash Attention loads keys, queries, and values once, fuses the operations of the attention mechanism, and writes them back. + +![Flash Attention](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/flash-attn.png) + +It is implemented for supported models. You can check out the complete list of models that support Flash Attention [here](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models), for models with flash prefix. + +You can learn more about Flash Attention by reading the paper in this [link](https://arxiv.org/abs/2205.14135). 
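+
+To make the fused, block-by-block computation described above more concrete, here is a minimal, illustrative NumPy sketch of the online-softmax trick for a single query. It is not TGI's implementation (TGI uses the custom CUDA kernels from the Flash Attention project); it only shows why attention can be computed one key/value block at a time without ever materializing the full row of scores:
+
+```python
+import numpy as np
+
+def blockwise_attention(q, K, V, block_size=64):
+    """Illustrative only: attention for a single query computed over key/value
+    blocks, keeping a running max and normalizer so the full row of scores
+    never has to be materialized (the core idea behind Flash Attention)."""
+    d = q.shape[0]
+    m = -np.inf                     # running max of the scores seen so far
+    l = 0.0                         # running softmax normalizer
+    acc = np.zeros(V.shape[1])      # running (unnormalized) weighted sum of values
+    for start in range(0, K.shape[0], block_size):
+        K_blk, V_blk = K[start:start + block_size], V[start:start + block_size]
+        s = K_blk @ q / np.sqrt(d)  # scores for this block only
+        m_new = max(m, s.max())
+        scale = np.exp(m - m_new)   # rescale what was accumulated so far
+        p = np.exp(s - m_new)       # unnormalized probabilities for this block
+        acc = acc * scale + p @ V_blk
+        l = l * scale + p.sum()
+        m = m_new
+    return acc / l
+
+# Sanity check against the naive implementation:
+rng = np.random.default_rng(0)
+q, K, V = rng.normal(size=(16,)), rng.normal(size=(256, 16)), rng.normal(size=(256, 8))
+scores = K @ q / np.sqrt(16)
+weights = np.exp(scores - scores.max()) / np.exp(scores - scores.max()).sum()
+assert np.allclose(blockwise_attention(q, K, V), weights @ V)
+```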
diff --git a/docs/source/conceptual/guidance.md b/docs/source/conceptual/guidance.md new file mode 100644 index 00000000..3059e3de --- /dev/null +++ b/docs/source/conceptual/guidance.md @@ -0,0 +1,86 @@ +# Guidance + +## What is Guidance? + +Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON. + +## How is it used? + +Guidance can be implemented in many ways and the community is always finding new ways to use it. Here are some examples of how you can use guidance: + +Technically, guidance can be used to generate: + +- a specific JSON object +- a function signature +- typed output like a list of integers + +However these use cases can span a wide range of applications, such as: + +- extracting structured data from unstructured text +- summarizing text into a specific format +- limit output to specific classes of words (act as a LLM powered classifier) +- generate the input to specific APIs or services +- provide reliable and consistent output for downstream tasks +- extract data from multimodal inputs + +## How it works? + +Diving into the details, guidance is enabled by including a grammar with a generation request that is compiled, and used to modify the chosen tokens. + +This process can be broken down into the following steps: + +1. A request is sent to the backend, it is processed and placed in batch. Processing includes compiling the grammar into a finite state machine and a grammar state. + +
+ + +
+
+2. The model does a forward pass over the batch. This returns probabilities for each token in the vocabulary for each request in the batch.
+
+3. The process of choosing one of those tokens is called `sampling`. The model samples from the distribution of probabilities to choose the next token. In TGI, all of the steps before sampling are called the `processor`. Grammars are applied as a processor that masks out tokens that are not allowed by the grammar.
+
+ + +
+
+4. The grammar mask is applied and the model samples from the remaining tokens. Once a token is chosen, we update the grammar state with the new token to prepare it for the next pass, as illustrated in the sketch below.
+
+ + +
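+As a rough illustration of steps 3 and 4 above, the grammar processor can be thought of as a mask over the logits: every token the current grammar state does not allow is set to `-inf` before sampling. The sketch below is a simplified stand-in for what outlines' compiled finite-state machine does inside TGI; the helper name and the hard-coded allowed set are hypothetical, purely for illustration:
+
+```python
+import numpy as np
+
+def sample_with_grammar(logits, allowed_token_ids, rng):
+    """Illustrative sketch: keep only the tokens the grammar state allows,
+    mask everything else to -inf, and sample from what remains.
+    In TGI the allowed set comes from outlines' compiled FSM; here it is a plain list."""
+    masked = np.full_like(logits, -np.inf)
+    masked[allowed_token_ids] = logits[allowed_token_ids]  # grammar-legal tokens survive
+    probs = np.exp(masked - masked.max())
+    probs /= probs.sum()
+    return rng.choice(len(logits), p=probs)
+
+# Toy example: suppose the grammar state only allows token ids 2, 5 and 7 at this step.
+rng = np.random.default_rng(42)
+logits = rng.normal(size=32)           # stand-in for one row of the model's output
+next_token = sample_with_grammar(logits, [2, 5, 7], rng)
+assert next_token in {2, 5, 7}
+# Step 4: after choosing the token, the grammar state is advanced so the next
+# forward pass is masked with the new set of allowed tokens.
+```
+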
+ +## How to use Guidance? + +There are two main ways to use guidance; you can either use the `/generate` endpoint with a grammar or use the `/chat/completion` endpoint with tools. + +Under the hood tools are a special case of grammars that allows the model to choose one or none of the provided tools. + +Please refer to [using guidance](../basic_tutorials/using_guidance) for more examples and details on how to use guidance in Python, JavaScript, and cURL. + +### Getting the most out of guidance + +Depending on how you are using guidance, you may want to make use of different features. Here are some tips to get the most out of guidance: + +- If you are using the `/generate` with a `grammar` it is recommended to include the grammar in the prompt prefixed by something like `Please use the following JSON schema to generate the output:`. This will help the model understand the context of the grammar and generate the output accordingly. +- If you are getting a response with many repeated tokens, please use the `frequency_penalty` or `repetition_penalty` to reduce the number of repeated tokens in the output. diff --git a/docs/source/conceptual/paged_attention.md b/docs/source/conceptual/paged_attention.md new file mode 100644 index 00000000..3fb2dcd8 --- /dev/null +++ b/docs/source/conceptual/paged_attention.md @@ -0,0 +1,9 @@ +# PagedAttention + +LLMs struggle with memory limitations during generation. In the decoding part of generation, all the attention keys and values generated for previous tokens are stored in GPU memory for reuse. This is called _KV cache_, and it may take up a large amount of memory for large models and long sequences. + +PagedAttention attempts to optimize memory use by partitioning the KV cache into blocks that are accessed through a lookup table. Thus, the KV cache does not need to be stored in contiguous memory, and blocks are allocated as needed. The memory efficiency can increase GPU utilization on memory-bound workloads, so more inference batches can be supported. + +The use of a lookup table to access the memory blocks can also help with KV sharing across multiple generations. This is helpful for techniques such as _parallel sampling_, where multiple outputs are generated simultaneously for the same prompt. In this case, the cached KV blocks can be shared among the generations. + +TGI's PagedAttention implementation leverages the custom cuda kernels developed by the [vLLM Project](https://github.com/vllm-project/vllm). You can learn more about this technique in the [project's page](https://vllm.ai/). diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md new file mode 100644 index 00000000..8f26fdba --- /dev/null +++ b/docs/source/conceptual/quantization.md @@ -0,0 +1,59 @@ +# Quantization + +TGI offers GPTQ and bits-and-bytes quantization to quantize large language models. + +## Quantization with GPTQ + +GPTQ is a post-training quantization method to make the model smaller. It quantizes the layers by finding a compressed version of that weight, that will yield a minimum mean squared error like below 👇 + +Given a layer \\(l\\) with weight matrix \\(W_{l}\\) and layer input \\(X_{l}\\), find quantized weight \\(\\hat{W}_{l}\\): + +$$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$ + + +TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. 
You can run a quantized model by simply passing --quantize like below 👇 + +```bash +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize gptq +``` + +Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI. + +To quantize a given model using GPTQ with a calibration dataset, simply run + +```bash +text-generation-server quantize tiiuae/falcon-40b /data/falcon-40b-gptq +# Add --upload-to-model-id MYUSERNAME/falcon-40b to push the created model to the hub directly +``` + +This will create a new directory with the quantized files which you can use with, + +```bash +text-generation-launcher --model-id /data/falcon-40b-gptq/ --sharded true --num-shard 2 --quantize gptq +``` + +You can learn more about the quantization options by running `text-generation-server quantize --help`. + +If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration). +You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf). + +## Quantization with bitsandbytes + +bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing – weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision. + +8-bit quantization enables multi-billion parameter scale models to fit in smaller hardware without degrading performance too much. +In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇 + +```bash +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes +``` + +4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load. + +In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇 + +```bash +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4 +``` + +You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). diff --git a/docs/source/conceptual/safetensors.md b/docs/source/conceptual/safetensors.md new file mode 100644 index 00000000..8ede20fe --- /dev/null +++ b/docs/source/conceptual/safetensors.md @@ -0,0 +1,7 @@ +# Safetensors + +Safetensors is a model serialization format for deep learning models. It is [faster](https://huggingface.co/docs/safetensors/speed) and safer compared to other serialization formats like pickle (which is used under the hood in many deep learning libraries). + +TGI depends on safetensors format mainly to enable [tensor parallelism sharding](./tensor_parallelism). 
When serving a given model repository, TGI looks for safetensors weights. If there are no safetensors weights, TGI converts the PyTorch weights to safetensors format.
+
+You can learn more about safetensors by reading the [safetensors documentation](https://huggingface.co/docs/safetensors/index).
diff --git a/docs/source/conceptual/speculation.md b/docs/source/conceptual/speculation.md
new file mode 100644
index 00000000..45618ae3
--- /dev/null
+++ b/docs/source/conceptual/speculation.md
@@ -0,0 +1,49 @@
+## Speculation
+
+
+Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea.
+The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens were valid.
+
+So you are doing *more* computation on your LLM, but if your guesses are correct you produce 1, 2, 3, etc. tokens in a single LLM pass. Since LLMs are usually memory bound (and not compute bound), provided your guesses are correct enough, this results in 2-3x faster inference (it can be much more for code-oriented tasks, for instance).
+
+You can read a more [detailed explanation](https://huggingface.co/blog/assisted-generation).
+
+Text Generation Inference supports 2 main speculative methods:
+
+- Medusa
+- N-gram
+
+
+### Medusa
+
+
+Medusa is a [simple method](https://arxiv.org/abs/2401.10774) to create many tokens in a single pass using fine-tuned LM heads in addition to your existing models.
+
+
+You can check a few existing fine-tunes for popular models:
+
+- [text-generation-inference/gemma-7b-it-medusa](https://huggingface.co/text-generation-inference/gemma-7b-it-medusa)
+- [text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa](https://huggingface.co/text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa)
+- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)
+
+
+In order to create your own medusa heads for your own fine-tune, you should check out the original medusa repo and the [Train Medusa](../basic_tutorials/train_medusa.md) tutorial.
+
+
+In order to use medusa models in TGI, simply point to a medusa-enabled model, and everything will load automatically.
+
+
+### N-gram
+
+
+If you don't have a medusa model, or don't have the resources to fine-tune, you can try to use `n-gram`.
+N-gram works by trying to find matching tokens in the previous sequence, and uses those as speculation for generating new tokens. For example, if the tokens "np.mean" appear multiple times in the sequence, the model can speculate that the next continuation of the tokens "np." is probably also "mean".
+
+This is an extremely simple method that works best for code or highly repetitive text. It might not be beneficial if the speculation misses too often.
+
+
+In order to enable n-gram speculation, simply pass
+
+`--speculate 2` in your flags.
+
+[Details about the flag](https://huggingface.co/docs/text-generation-inference/basic_tutorials/launcher#speculate)
diff --git a/docs/source/conceptual/streaming.md b/docs/source/conceptual/streaming.md
new file mode 100644
index 00000000..71ec9b25
--- /dev/null
+++ b/docs/source/conceptual/streaming.md
@@ -0,0 +1,146 @@
+# Streaming
+
+## What is Streaming?
+
+Token streaming is the mode in which the server returns the tokens one by one as the model generates them. This enables showing progressive generations to the user rather than waiting for the whole generation.
Streaming is an essential aspect of the end-user experience as it reduces latency, one of the most critical aspects of a smooth experience. + +
+ + +
+ +With token streaming, the server can start returning the tokens one by one before having to generate the whole response. Users can have a sense of the generation's quality before the end of the generation. This has different positive effects: + +* Users can get results orders of magnitude earlier for extremely long queries. +* Seeing something in progress allows users to stop the generation if it's not going in the direction they expect. +* Perceived latency is lower when results are shown in the early stages. +* When used in conversational UIs, the experience feels more natural. + +For example, a system can generate 100 tokens per second. If the system generates 1000 tokens, with the non-streaming setup, users need to wait 10 seconds to get results. On the other hand, with the streaming setup, users get initial results immediately, and although end-to-end latency will be the same, they can see half of the generation after five seconds. Below you can see an interactive demo that shows non-streaming vs streaming side-by-side. Click **generate** below. + +
+ +
+ + +## How to use Streaming? + +### Streaming with Python + +To stream tokens with `InferenceClient`, simply pass `stream=True` and iterate over the response. + +```python +from huggingface_hub import InferenceClient + +client = InferenceClient("http://127.0.0.1:8080") +for token in client.text_generation("How do you make cheese?", max_new_tokens=12, stream=True): + print(token) + +# To +# make +# cheese +#, +# you +# need +# to +# start +# with +# milk +#. +``` + +If you want additional details, you can add `details=True`. In this case, you get a `TextGenerationStreamResponse` which contains additional information such as the probabilities and the tokens. For the final response in the stream, it also returns the full generated text. + +```python +for details in client.text_generation("How do you make cheese?", max_new_tokens=12, details=True, stream=True): + print(details) + +#TextGenerationStreamResponse(token=Token(id=193, text='\n', logprob=-0.007358551, special=False), generated_text=None, details=None) +#TextGenerationStreamResponse(token=Token(id=2044, text='To', logprob=-1.1357422, special=False), generated_text=None, details=None) +#TextGenerationStreamResponse(token=Token(id=717, text=' make', logprob=-0.009841919, special=False), generated_text=None, details=None) +#... +#TextGenerationStreamResponse(token=Token(id=25, text='.', logprob=-1.3408203, special=False), generated_text='\nTo make cheese, you need to start with milk.', details=StreamDetails(finish_reason=, generated_tokens=12, seed=None)) +``` + +The `huggingface_hub` library also comes with an `AsyncInferenceClient` in case you need to handle the requests concurrently. + +```python +from huggingface_hub import AsyncInferenceClient + +client = AsyncInferenceClient("http://127.0.0.1:8080") +async for token in await client.text_generation("How do you make cheese?", stream=True): + print(token) + +# To +# make +# cheese +#, +# you +# need +# to +# start +# with +# milk +#. +``` + +### Streaming with cURL + +To use the `generate_stream` endpoint with curl, you can add the `-N` flag, which disables curl default buffering and shows data as it arrives from the server + +```curl +curl -N 127.0.0.1:8080/generate_stream \ + -X POST \ + -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ + -H 'Content-Type: application/json' +``` + +### Streaming with JavaScript + +First, we need to install the `@huggingface/inference` library. +`npm install @huggingface/inference` + +If you're using the free Inference API, you can use `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint`. + +We can create a `HfInferenceEndpoint` providing our endpoint URL and credential. + +```js +import { HfInferenceEndpoint } from '@huggingface/inference' + +const hf = new HfInferenceEndpoint('https://YOUR_ENDPOINT.endpoints.huggingface.cloud', 'hf_YOUR_TOKEN') + +// prompt +const prompt = 'What can you do in Nuremberg, Germany? Give me 3 Tips' + +const stream = hf.textGenerationStream({ inputs: prompt }) +for await (const r of stream) { + // yield the generated token + process.stdout.write(r.token.text) +} +``` + +## How does Streaming work under the hood? + +Under the hood, TGI uses Server-Sent Events (SSE). In an SSE Setup, a client sends a request with the data, opening an HTTP connection and subscribing to updates. Afterward, the server sends data to the client. There is no need for further requests; the server will keep sending the data. 
SSEs are unidirectional, meaning the client does not send other requests to the server. SSE sends data over HTTP, making it easy to use. + +SSEs are different than: +* Polling: where the client keeps calling the server to get data. This means that the server might return empty responses and cause overhead. +* Webhooks: where there is a bi-directional connection. The server can send information to the client, but the client can also send data to the server after the first request. Webhooks are more complex to operate as they don’t only use HTTP. + +If there are too many requests at the same time, TGI returns an HTTP Error with an `overloaded` error type (`huggingface_hub` returns `OverloadedError`). This allows the client to manage the overloaded server (e.g., it could display a busy error to the user or retry with a new request). To configure the maximum number of concurrent requests, you can specify `--max_concurrent_requests`, allowing clients to handle backpressure. diff --git a/docs/source/conceptual/tensor_parallelism.md b/docs/source/conceptual/tensor_parallelism.md new file mode 100644 index 00000000..2c241c41 --- /dev/null +++ b/docs/source/conceptual/tensor_parallelism.md @@ -0,0 +1,14 @@ +# Tensor Parallelism + +Tensor parallelism is a technique used to fit a large model in multiple GPUs. For example, when multiplying the input tensors with the first weight tensor, the matrix multiplication is equivalent to splitting the weight tensor column-wise, multiplying each column with the input separately, and then concatenating the separate outputs. These outputs are then transferred from the GPUs and concatenated together to get the final result, like below 👇 + +![Image courtesy of Anton Lozkhov](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/TP.png) + + + + +Tensor Parallelism only works for [models officially supported](../supported_models), it will not work when falling back to `transformers`. You can get more information about unsupported models [here](../basic_tutorials/non_core_models). + + + +You can learn a lot more details about tensor-parallelism from [the `transformers` docs](https://huggingface.co/docs/transformers/main/en/perf_train_gpu_many#tensor-parallelism). diff --git a/docs/source/index.md b/docs/source/index.md new file mode 100644 index 00000000..309442b1 --- /dev/null +++ b/docs/source/index.md @@ -0,0 +1,28 @@ +# Text Generation Inference + +Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and T5. 
+ +![Text Generation Inference](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png) + +Text Generation Inference implements many optimizations and features, such as: + +- Simple launcher to serve most popular LLMs +- Production ready (distributed tracing with Open Telemetry, Prometheus metrics) +- Tensor Parallelism for faster inference on multiple GPUs +- Token streaming using Server-Sent Events (SSE) +- Continuous batching of incoming requests for increased total throughput +- Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures +- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323) +- [Safetensors](https://github.com/huggingface/safetensors) weight loading +- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) +- Logits warper (temperature scaling, top-p, top-k, repetition penalty) +- Stop sequences +- Log probabilities +- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance. +- [Guidance](../conceptual/guidance): Enable function calling and tool-use by forcing the model to generate structured outputs based on your own predefined output schemas. + +Text Generation Inference is used in production by multiple projects, such as: + +- [Hugging Chat](https://github.com/huggingface/chat-ui), an open-source interface for open-access models, such as Open Assistant and Llama +- [OpenAssistant](https://open-assistant.io/), an open-source community effort to train LLMs in the open +- [nat.dev](http://nat.dev/), a playground to explore and compare LLMs. diff --git a/docs/source/installation.md b/docs/source/installation.md new file mode 100644 index 00000000..b6c24d55 --- /dev/null +++ b/docs/source/installation.md @@ -0,0 +1,83 @@ +# Installation from source + + + +Installing TGI from source is not the recommended usage. We strongly recommend to use TGI through Docker, check the [Quick Tour](./quicktour), [Installation for Nvidia GPUs](./installation_nvidia) and [Installation for AMD GPUs](./installation_amd) to learn how to use TGI with Docker. + + + +## Install CLI + +You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. + +To install the CLI, you need to first clone the TGI repository and then run `make`. + +```bash +git clone https://github.com/huggingface/text-generation-inference.git && cd text-generation-inference +make install +``` + +If you would like to serve models with custom kernels, run + +```bash +BUILD_EXTENSIONS=True make install +``` + +## Local Installation from Source + +Before you start, you will need to setup your environment, and install Text Generation Inference. Text Generation Inference is tested on **Python 3.9+**. + +Text Generation Inference is available on pypi, conda and GitHub. + +To install and launch locally, first [install Rust](https://rustup.rs/) and create a Python virtual environment with at least +Python 3.9, e.g. using conda: + +```bash +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + +conda create -n text-generation-inference python=3.9 +conda activate text-generation-inference +``` + +You may also need to install Protoc. 
+ +On Linux: + +```bash +PROTOC_ZIP=protoc-21.12-linux-x86_64.zip +curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP +sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc +sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*' +rm -f $PROTOC_ZIP +``` + +On MacOS, using Homebrew: + +```bash +brew install protobuf +``` + +Then run to install Text Generation Inference: + +```bash +git clone https://github.com/huggingface/text-generation-inference.git && cd text-generation-inference +BUILD_EXTENSIONS=True make install +``` + + + +On some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run: + +```bash +sudo apt-get install libssl-dev gcc -y +``` + + + +Once installation is done, simply run: + +```bash +make run-falcon-7b-instruct +``` + +This will serve Falcon 7B Instruct model from the port 8080, which we can query. diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md new file mode 100644 index 00000000..d70953ae --- /dev/null +++ b/docs/source/installation_amd.md @@ -0,0 +1,38 @@ +# Using TGI with AMD GPUs + +TGI is supported and tested on [AMD Instinct MI210](https://www.amd.com/en/products/accelerators/instinct/mi200/mi210.html), [MI250](https://www.amd.com/en/products/accelerators/instinct/mi200/mi250.html) and [MI300](https://www.amd.com/en/products/accelerators/instinct/mi300.html) GPUs. The support may be extended in the future. The recommended usage is through Docker. Make sure to check the [AMD documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html) on how to use Docker with AMD GPUs. + +On a server powered by AMD GPUs, TGI can be launched with the following command: + +```bash +model=teknium/OpenHermes-2.5-Mistral-7B +volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run + +docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ + --device=/dev/kfd --device=/dev/dri --group-add video \ + --ipc=host --shm-size 256g --net host -v $volume:/data \ + ghcr.io/huggingface/text-generation-inference:2.0.4-rocm \ + --model-id $model +``` + +The launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide. + +## TunableOp + +TGI's docker image for AMD GPUs integrates [PyTorch's TunableOp](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable), which allows to do an additional warmup to select the best performing matrix multiplication (GEMM) kernel from rocBLAS or hipBLASLt. + +Experimentally, on MI300X, we noticed a 6-8% latency improvement when using TunableOp on top of ROCm 6.1 and PyTorch 2.3. + +TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launcher TGI's docker container. + +## Flash attention implementation + +Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/flash_attn_triton.py). + +By default, the Composable Kernel implementation is used. 
However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton impelmentation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container. + +## Unsupported features + +The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future: +* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints. +* Kernel for sliding window attention (Mistral) diff --git a/docs/source/installation_gaudi.md b/docs/source/installation_gaudi.md new file mode 100644 index 00000000..1ddf2b47 --- /dev/null +++ b/docs/source/installation_gaudi.md @@ -0,0 +1,3 @@ +# Using TGI with Intel Gaudi + +Check out this [repository](https://github.com/huggingface/tgi-gaudi) to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index). diff --git a/docs/source/installation_inferentia.md b/docs/source/installation_inferentia.md new file mode 100644 index 00000000..0394e6de --- /dev/null +++ b/docs/source/installation_inferentia.md @@ -0,0 +1,3 @@ +# Using TGI with Inferentia + +Check out this [guide](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference) on how to serve models with TGI on Inferentia2. diff --git a/docs/source/installation_nvidia.md b/docs/source/installation_nvidia.md new file mode 100644 index 00000000..9077f7fd --- /dev/null +++ b/docs/source/installation_nvidia.md @@ -0,0 +1,18 @@ +# Using TGI with Nvidia GPUs + +TGI optimized models are supported on NVIDIA [H100](https://www.nvidia.com/en-us/data-center/h100/), [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. + +For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed. + +TGI can be used on NVIDIA GPUs through its official docker image: + +```bash +model=teknium/OpenHermes-2.5-Mistral-7B +volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run + +docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ + ghcr.io/huggingface/text-generation-inference:2.0.4 \ + --model-id $model +``` + +The launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide. diff --git a/docs/source/messages_api.md b/docs/source/messages_api.md new file mode 100644 index 00000000..250aaae2 --- /dev/null +++ b/docs/source/messages_api.md @@ -0,0 +1,175 @@ +# Messages API + +Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version 1.4.0. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility. + +> **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature. 
+
+#### Table of Contents
+
+- [Making a Request](#making-a-request)
+- [Streaming](#streaming)
+- [Synchronous](#synchronous)
+- [Hugging Face Inference Endpoints](#hugging-face-inference-endpoints)
+- [Cloud Providers](#cloud-providers)
+  - [Amazon SageMaker](#amazon-sagemaker)
+
+## Making a Request
+
+You can make a request to TGI's Messages API using `curl`. Here's an example:
+
+```bash
+curl localhost:3000/v1/chat/completions \
+    -X POST \
+    -d '{
+  "model": "tgi",
+  "messages": [
+    {
+      "role": "system",
+      "content": "You are a helpful assistant."
+    },
+    {
+      "role": "user",
+      "content": "What is deep learning?"
+    }
+  ],
+  "stream": true,
+  "max_tokens": 20
+}' \
+    -H 'Content-Type: application/json'
+```
+
+## Streaming
+
+You can also use OpenAI's Python client library to make a streaming request. Here's how:
+
+```python
+from openai import OpenAI
+
+# init the client but point it to TGI
+client = OpenAI(
+    base_url="http://localhost:3000/v1",
+    api_key="-"
+)
+
+chat_completion = client.chat.completions.create(
+    model="tgi",
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant." },
+        {"role": "user", "content": "What is deep learning?"}
+    ],
+    stream=True
+)
+
+# iterate and print stream
+for message in chat_completion:
+    print(message)
+```
+
+## Synchronous
+
+If you prefer to make a synchronous request, you can do so like this:
+
+```python
+from openai import OpenAI
+
+# init the client but point it to TGI
+client = OpenAI(
+    base_url="http://localhost:3000/v1",
+    api_key="-"
+)
+
+chat_completion = client.chat.completions.create(
+    model="tgi",
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant." },
+        {"role": "user", "content": "What is deep learning?"}
+    ],
+    stream=False
+)
+
+print(chat_completion)
+```
+
+## Hugging Face Inference Endpoints
+
+The Messages API is integrated with [Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated).
+Every endpoint that uses "Text Generation Inference" with an LLM that has a chat template can now be used. Below is an example of how to use IE with TGI using OpenAI's Python client library:
+
+> **Note:** Make sure to replace `base_url` with your endpoint URL and to include `v1/` at the end of the URL. The `api_key` should be replaced with your Hugging Face API key.
+
+```python
+from openai import OpenAI
+
+# init the client but point it to TGI
+client = OpenAI(
+    # replace with your endpoint url, make sure to include "v1/" at the end
+    base_url="https://vlzz10eq3fol3429.us-east-1.aws.endpoints.huggingface.cloud/v1/",
+    # replace with your API key
+    api_key="hf_XXX"
+)
+
+chat_completion = client.chat.completions.create(
+    model="tgi",
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant." },
+        {"role": "user", "content": "What is deep learning?"}
+    ],
+    stream=True
+)
+
+# iterate and print stream
+for message in chat_completion:
+    print(message.choices[0].delta.content, end="")
+```
+
+## Cloud Providers
+
+TGI can be deployed on various cloud providers for scalable and robust text generation. One such provider is Amazon SageMaker, which has recently added support for TGI. Here's how you can deploy TGI on Amazon SageMaker:
+
+## Amazon SageMaker
+
+To enable the Messages API in Amazon SageMaker, you need to set the environment variable `MESSAGES_API_ENABLED=true`.
+
+This will modify the `/invocations` route to accept Messages dictionaries consisting of role and content. See the example below on how to deploy Llama with the new Messages API.
+
+```python
+import json
+import sagemaker
+import boto3
+from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
+
+try:
+    role = sagemaker.get_execution_role()
+except ValueError:
+    iam = boto3.client('iam')
+    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
+
+# Hub Model configuration. https://huggingface.co/models
+hub = {
+    'HF_MODEL_ID': 'HuggingFaceH4/zephyr-7b-beta',
+    'SM_NUM_GPUS': json.dumps(1),
+    'MESSAGES_API_ENABLED': True
+}
+
+# create Hugging Face Model Class
+huggingface_model = HuggingFaceModel(
+    image_uri=get_huggingface_llm_image_uri("huggingface", version="1.4.0"),
+    env=hub,
+    role=role,
+)
+
+# deploy model to SageMaker Inference
+predictor = huggingface_model.deploy(
+    initial_instance_count=1,
+    instance_type="ml.g5.2xlarge",
+    container_startup_health_check_timeout=300,
+)
+
+# send request
+predictor.predict({
+    "messages": [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is deep learning?"}
+    ]
+})
+```
diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md
new file mode 100644
index 00000000..b84de85d
--- /dev/null
+++ b/docs/source/quicktour.md
@@ -0,0 +1,94 @@
+# Quick Tour
+
+The easiest way of getting started is using the official Docker container. Install Docker following [their installation instructions](https://docs.docker.com/get-docker/).
+
+## Launching TGI
+
+Let's say you want to deploy the [teknium/OpenHermes-2.5-Mistral-7B](https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B) model with TGI on an Nvidia GPU. Here is an example of how to do that:
+
+```bash
+model=teknium/OpenHermes-2.5-Mistral-7B
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
+    ghcr.io/huggingface/text-generation-inference:2.0.4 \
+    --model-id $model
+```
+
+### Supported hardware
+
+TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
+
+## Consuming TGI
+
+Once TGI is running, you can use the `generate` endpoint by sending requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
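+
+If you have the official Python client installed (`pip install text-generation`), a minimal sketch of the same query could also look like this (assuming TGI is listening on port 8080, as in the launch command above):
+
+```python
+from text_generation import Client
+
+# Point the client at the locally running TGI server.
+client = Client("http://127.0.0.1:8080")
+
+# Generate up to 20 new tokens for the prompt and print the result.
+print(client.generate("What is Deep Learning?", max_new_tokens=20).generated_text)
+```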
+
+
+
+```python
+import requests
+
+headers = {
+    "Content-Type": "application/json",
+}
+
+data = {
+    'inputs': 'What is Deep Learning?',
+    'parameters': {
+        'max_new_tokens': 20,
+    },
+}
+
+response = requests.post('http://127.0.0.1:8080/generate', headers=headers, json=data)
+print(response.json())
+# {'generated_text': '\n\nDeep Learning is a subset of Machine Learning that is concerned with the development of algorithms that can'}
+```
+
+
+
+```js
+async function query() {
+    const response = await fetch(
+        'http://127.0.0.1:8080/generate',
+        {
+            method: 'POST',
+            headers: { 'Content-Type': 'application/json'},
+            body: JSON.stringify({
+                'inputs': 'What is Deep Learning?',
+                'parameters': {
+                    'max_new_tokens': 20
+                }
+            })
+        }
+    );
+    // parse the JSON body so the caller receives the generated text
+    return await response.json();
+}
+
+query().then((response) => {
+    console.log(JSON.stringify(response));
+});
+/// {"generated_text":"\n\nDeep Learning is a subset of Machine Learning that is concerned with the development of algorithms that can"}
+```
+
+
+
+```curl
+curl 127.0.0.1:8080/generate \
+    -X POST \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+    -H 'Content-Type: application/json'
+```
+
+
+
+
+To see all possible deployment flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
+
+```bash
+docker run ghcr.io/huggingface/text-generation-inference:2.0.4 --help
+```
+
+
diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md
new file mode 100644
index 00000000..4b6cf731
--- /dev/null
+++ b/docs/source/supported_models.md
@@ -0,0 +1,48 @@
+
+# Supported Models and Hardware
+
+Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models and hardware are supported.
+
+## Supported Models
+
+- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
+- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
+- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
+- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
+- [Gemma](https://huggingface.co/google/gemma-7b)
+- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
+- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
+- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)
+- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
+- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1)
+- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
+- [Phi](https://huggingface.co/microsoft/phi-1_5)
+- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
+- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
+- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
+- [Qwen 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
+- [Opt](https://huggingface.co/facebook/opt-6.7b)
+- [T5](https://huggingface.co/google/flan-t5-xxl)
+- [Galactica](https://huggingface.co/facebook/galactica-120b)
+- [SantaCoder](https://huggingface.co/bigcode/santacoder)
+- [Bloom](https://huggingface.co/bigscience/bloom-560m)
+- [Mpt](https://huggingface.co/mosaicml/mpt-7b-instruct)
+- [Gpt2](https://huggingface.co/openai-community/gpt2)
+- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
+- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)
+
+
+If the model you would like to serve is not in the list above, you can still try to initialize and serve it, depending on the model's pipeline type, to see how well it performs; however, performance isn't guaranteed for non-optimized models:
+
+```python
+from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
+
+# model_id is the Hub id or local path of the model you want to try
+# for causal LMs/text-generation models
+AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+# or, for text-to-text generation models
+AutoModelForSeq2SeqLM.from_pretrained(model_id, device_map="auto")
+```
+
+If you wish to serve a supported model that already exists in a local folder, just point to the local folder.
+ +```bash +text-generation-launcher --model-id +``` diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py new file mode 100644 index 00000000..2ef85da6 --- /dev/null +++ b/integration-tests/conftest.py @@ -0,0 +1,517 @@ +import sys +import subprocess +import contextlib +import pytest +import asyncio +import os +import docker +import json +import math +import shutil +import tempfile +import time +import random + +from docker.errors import NotFound +from typing import Optional, List, Dict +from syrupy.extensions.json import JSONSnapshotExtension +from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError + +from text_generation import AsyncClient +from text_generation.types import ( + Response, + Details, + InputToken, + Token, + BestOfSequence, + Grammar, + ChatComplete, + ChatCompletionChunk, + ChatCompletionComplete, + Completion, +) + +DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) +HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) +DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") + + +class ResponseComparator(JSONSnapshotExtension): + rtol = 0.2 + ignore_logprob = False + + def serialize( + self, + data, + *, + exclude=None, + matcher=None, + ): + if ( + isinstance(data, Response) + or isinstance(data, ChatComplete) + or isinstance(data, ChatCompletionChunk) + or isinstance(data, ChatCompletionComplete) + ): + data = data.model_dump() + + if isinstance(data, List): + data = [d.model_dump() for d in data] + + data = self._filter( + data=data, depth=0, path=(), exclude=exclude, matcher=matcher + ) + return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n" + + def matches( + self, + *, + serialized_data, + snapshot_data, + ) -> bool: + def convert_data(data): + data = json.loads(data) + if isinstance(data, Dict) and "choices" in data: + choices = data["choices"] + if isinstance(choices, List) and len(choices) >= 1: + if "delta" in choices[0]: + return ChatCompletionChunk(**data) + if "text" in choices[0]: + return Completion(**data) + return ChatComplete(**data) + + if isinstance(data, Dict): + return Response(**data) + if isinstance(data, List): + if ( + len(data) > 0 + and "object" in data[0] + and data[0]["object"] == "text_completion" + ): + return [Completion(**d) for d in data] + return [Response(**d) for d in data] + raise NotImplementedError + + def eq_token(token: Token, other: Token) -> bool: + return ( + token.id == other.id + and token.text == other.text + and ( + self.ignore_logprob + or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol) + ) + and token.special == other.special + ) + + def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool: + try: + return ( + prefill_token.id == other.id + and prefill_token.text == other.text + and ( + self.ignore_logprob + or math.isclose( + prefill_token.logprob, + other.logprob, + rel_tol=self.rtol, + ) + if prefill_token.logprob is not None + else prefill_token.logprob == other.logprob + ) + ) + except TypeError: + return False + + def eq_best_of(details: BestOfSequence, other: BestOfSequence) -> bool: + return ( + details.finish_reason == other.finish_reason + and details.generated_tokens == other.generated_tokens + and details.seed == other.seed + and len(details.prefill) == len(other.prefill) + and all( + [ + eq_prefill_token(d, o) + for d, o in zip(details.prefill, other.prefill) + ] + ) + and len(details.tokens) == len(other.tokens) + and all([eq_token(d, o) for d, o in zip(details.tokens, other.tokens)]) + ) + + def 
eq_details(details: Details, other: Details) -> bool: + return ( + details.finish_reason == other.finish_reason + and details.generated_tokens == other.generated_tokens + and details.seed == other.seed + and len(details.prefill) == len(other.prefill) + and all( + [ + eq_prefill_token(d, o) + for d, o in zip(details.prefill, other.prefill) + ] + ) + and len(details.tokens) == len(other.tokens) + and all([eq_token(d, o) for d, o in zip(details.tokens, other.tokens)]) + and ( + len(details.best_of_sequences) + if details.best_of_sequences is not None + else 0 + ) + == ( + len(other.best_of_sequences) + if other.best_of_sequences is not None + else 0 + ) + and ( + all( + [ + eq_best_of(d, o) + for d, o in zip( + details.best_of_sequences, other.best_of_sequences + ) + ] + ) + if details.best_of_sequences is not None + else details.best_of_sequences == other.best_of_sequences + ) + ) + + def eq_completion(response: Completion, other: Completion) -> bool: + return response.choices[0].text == other.choices[0].text + + def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool: + return ( + response.choices[0].message.content == other.choices[0].message.content + ) + + def eq_chat_complete_chunk( + response: ChatCompletionChunk, other: ChatCompletionChunk + ) -> bool: + return response.choices[0].delta.content == other.choices[0].delta.content + + def eq_response(response: Response, other: Response) -> bool: + return response.generated_text == other.generated_text and eq_details( + response.details, other.details + ) + + serialized_data = convert_data(serialized_data) + snapshot_data = convert_data(snapshot_data) + + if not isinstance(serialized_data, List): + serialized_data = [serialized_data] + if not isinstance(snapshot_data, List): + snapshot_data = [snapshot_data] + + if isinstance(serialized_data[0], Completion): + return len(snapshot_data) == len(serialized_data) and all( + [eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)] + ) + + if isinstance(serialized_data[0], ChatComplete): + return len(snapshot_data) == len(serialized_data) and all( + [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)] + ) + + if isinstance(serialized_data[0], ChatCompletionChunk): + return len(snapshot_data) == len(serialized_data) and all( + [ + eq_chat_complete_chunk(r, o) + for r, o in zip(serialized_data, snapshot_data) + ] + ) + + return len(snapshot_data) == len(serialized_data) and all( + [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)] + ) + + +class GenerousResponseComparator(ResponseComparator): + # Needed for GPTQ with exllama which has serious numerical fluctuations. 
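+    # With rel_tol=0.75, token logprobs may differ by up to 75% (relative) and still compare as equal.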
+ rtol = 0.75 + + +class IgnoreLogProbResponseComparator(ResponseComparator): + ignore_logprob = True + + +class LauncherHandle: + def __init__(self, port: int): + self.client = AsyncClient(f"http://localhost:{port}") + + def _inner_health(self): + raise NotImplementedError + + async def health(self, timeout: int = 60): + assert timeout > 0 + for _ in range(timeout): + if not self._inner_health(): + raise RuntimeError("Launcher crashed") + + try: + await self.client.generate("test") + return + except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e: + time.sleep(1) + raise RuntimeError("Health check failed") + + +class ContainerLauncherHandle(LauncherHandle): + def __init__(self, docker_client, container_name, port: int): + super(ContainerLauncherHandle, self).__init__(port) + self.docker_client = docker_client + self.container_name = container_name + + def _inner_health(self) -> bool: + container = self.docker_client.containers.get(self.container_name) + return container.status in ["running", "created"] + + +class ProcessLauncherHandle(LauncherHandle): + def __init__(self, process, port: int): + super(ProcessLauncherHandle, self).__init__(port) + self.process = process + + def _inner_health(self) -> bool: + return self.process.poll() is None + + +@pytest.fixture +def response_snapshot(snapshot): + return snapshot.use_extension(ResponseComparator) + + +@pytest.fixture +def generous_response_snapshot(snapshot): + return snapshot.use_extension(GenerousResponseComparator) + + +@pytest.fixture +def ignore_logprob_response_snapshot(snapshot): + return snapshot.use_extension(IgnoreLogProbResponseComparator) + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="module") +def launcher(event_loop): + @contextlib.contextmanager + def local_launcher( + model_id: str, + num_shard: Optional[int] = None, + quantize: Optional[str] = None, + trust_remote_code: bool = False, + use_flash_attention: bool = True, + disable_grammar_support: bool = False, + dtype: Optional[str] = None, + revision: Optional[str] = None, + max_input_length: Optional[int] = None, + max_batch_prefill_tokens: Optional[int] = None, + max_total_tokens: Optional[int] = None, + ): + port = random.randint(8000, 10_000) + master_port = random.randint(10_000, 20_000) + + shard_uds_path = ( + f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server" + ) + + args = [ + "text-generation-launcher", + "--model-id", + model_id, + "--port", + str(port), + "--master-port", + str(master_port), + "--shard-uds-path", + shard_uds_path, + ] + + env = os.environ + + if disable_grammar_support: + args.append("--disable-grammar-support") + if num_shard is not None: + args.extend(["--num-shard", str(num_shard)]) + if quantize is not None: + args.append("--quantize") + args.append(quantize) + if dtype is not None: + args.append("--dtype") + args.append(dtype) + if revision is not None: + args.append("--revision") + args.append(revision) + if trust_remote_code: + args.append("--trust-remote-code") + if max_input_length: + args.append("--max-input-length") + args.append(str(max_input_length)) + if max_batch_prefill_tokens: + args.append("--max-batch-prefill-tokens") + args.append(str(max_batch_prefill_tokens)) + if max_total_tokens: + args.append("--max-total-tokens") + args.append(str(max_total_tokens)) + + env["LOG_LEVEL"] = "info,text_generation_router=debug" + + if not use_flash_attention: + env["USE_FLASH_ATTENTION"] = "false" + 
+ with tempfile.TemporaryFile("w+") as tmp: + # We'll output stdout/stderr to a temporary file. Using a pipe + # cause the process to block until stdout is read. + with subprocess.Popen( + args, + stdout=tmp, + stderr=subprocess.STDOUT, + env=env, + ) as process: + yield ProcessLauncherHandle(process, port) + + process.terminate() + process.wait(60) + + tmp.seek(0) + shutil.copyfileobj(tmp, sys.stderr) + + if not use_flash_attention: + del env["USE_FLASH_ATTENTION"] + + @contextlib.contextmanager + def docker_launcher( + model_id: str, + num_shard: Optional[int] = None, + quantize: Optional[str] = None, + trust_remote_code: bool = False, + use_flash_attention: bool = True, + disable_grammar_support: bool = False, + dtype: Optional[str] = None, + revision: Optional[str] = None, + max_input_length: Optional[int] = None, + max_batch_prefill_tokens: Optional[int] = None, + max_total_tokens: Optional[int] = None, + ): + port = random.randint(8000, 10_000) + + args = ["--model-id", model_id, "--env"] + + if disable_grammar_support: + args.append("--disable-grammar-support") + if num_shard is not None: + args.extend(["--num-shard", str(num_shard)]) + if quantize is not None: + args.append("--quantize") + args.append(quantize) + if dtype is not None: + args.append("--dtype") + args.append(dtype) + if revision is not None: + args.append("--revision") + args.append(revision) + if trust_remote_code: + args.append("--trust-remote-code") + if max_input_length: + args.append("--max-input-length") + args.append(str(max_input_length)) + if max_batch_prefill_tokens: + args.append("--max-batch-prefill-tokens") + args.append(str(max_batch_prefill_tokens)) + if max_total_tokens: + args.append("--max-total-tokens") + args.append(str(max_total_tokens)) + + client = docker.from_env() + + container_name = f"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}" + + try: + container = client.containers.get(container_name) + container.stop() + container.wait() + except NotFound: + pass + + gpu_count = num_shard if num_shard is not None else 1 + + env = { + "LOG_LEVEL": "info,text_generation_router=debug", + } + if not use_flash_attention: + env["USE_FLASH_ATTENTION"] = "false" + + if HUGGING_FACE_HUB_TOKEN is not None: + env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN + + volumes = [] + if DOCKER_VOLUME: + volumes = [f"{DOCKER_VOLUME}:/data"] + + container = client.containers.run( + DOCKER_IMAGE, + command=args, + name=container_name, + environment=env, + auto_remove=False, + detach=True, + device_requests=[ + docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]]) + ], + volumes=volumes, + ports={"80/tcp": port}, + shm_size="1G", + ) + + yield ContainerLauncherHandle(client, container.name, port) + + if not use_flash_attention: + del env["USE_FLASH_ATTENTION"] + + try: + container.stop() + container.wait() + except NotFound: + pass + + container_output = container.logs().decode("utf-8") + print(container_output, file=sys.stderr) + + container.remove() + + if DOCKER_IMAGE is not None: + return docker_launcher + return local_launcher + + +@pytest.fixture(scope="module") +def generate_load(): + async def generate_load_inner( + client: AsyncClient, + prompt: str, + max_new_tokens: int, + n: int, + seed: Optional[int] = None, + grammar: Optional[Grammar] = None, + stop_sequences: Optional[List[str]] = None, + ) -> List[Response]: + futures = [ + client.generate( + prompt, + max_new_tokens=max_new_tokens, + decoder_input_details=True, + seed=seed, + grammar=grammar, + 
stop_sequences=stop_sequences, + ) + for _ in range(n) + ] + + return await asyncio.gather(*futures) + + return generate_load_inner diff --git a/integration-tests/images/chicken_on_money.png b/integration-tests/images/chicken_on_money.png new file mode 100644 index 00000000..1a4e0440 Binary files /dev/null and b/integration-tests/images/chicken_on_money.png differ diff --git a/integration-tests/images/cow_beach.png b/integration-tests/images/cow_beach.png new file mode 100644 index 00000000..d67f8a1b Binary files /dev/null and b/integration-tests/images/cow_beach.png differ diff --git a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json new file mode 100644 index 00000000..53a4ab85 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json @@ -0,0 +1,128 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 17934, + "logprob": null, + "text": "Pour" + }, + { + "id": 49833, + "logprob": -10.5625, + "text": " dég" + }, + { + "id": 21543, + "logprob": -0.14770508, + "text": "uster" + }, + { + "id": 447, + "logprob": -1.9287109, + "text": " un" + }, + { + "id": 46341, + "logprob": -15.4609375, + "text": " ort" + }, + { + "id": 35567, + "logprob": -7.5585938, + "text": "olan" + }, + { + "id": 15, + "logprob": -1.4003906, + "text": "," + }, + { + "id": 1669, + "logprob": -1.5673828, + "text": " il" + }, + { + "id": 11580, + "logprob": -0.94628906, + "text": " faut" + }, + { + "id": 3913, + "logprob": -3.703125, + "text": " tout" + }, + { + "id": 39261, + "logprob": -1.5732422, + "text": " d'abord" + } + ], + "seed": 0, + "tokens": [ + { + "id": 578, + "logprob": -1.6591797, + "special": false, + "text": " le" + }, + { + "id": 5608, + "logprob": -2.4492188, + "special": false, + "text": " faire" + }, + { + "id": 159570, + "logprob": -6.6835938, + "special": false, + "text": " réch" + }, + { + "id": 810, + "logprob": 0.0, + "special": false, + "text": "au" + }, + { + "id": 12736, + "logprob": 0.0, + "special": false, + "text": "ffer" + }, + { + "id": 1742, + "logprob": -2.5175781, + "special": false, + "text": " au" + }, + { + "id": 6105, + "logprob": -2.0078125, + "special": false, + "text": " bain" + }, + { + "id": 88254, + "logprob": -0.12695312, + "special": false, + "text": "-mar" + }, + { + "id": 641, + "logprob": 0.0, + "special": false, + "text": "ie" + }, + { + "id": 2940, + "logprob": -3.5175781, + "special": false, + "text": " avec" + } + ] + }, + "generated_text": " le faire réchauffer au bain-marie avec" +} diff --git a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json new file mode 100644 index 00000000..ace73416 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json @@ -0,0 +1,98 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 15, + "logprob": null, + "text": "," + }, + { + "id": 1669, + "logprob": -5.4414062, + "text": " il" + }, + { + "id": 11580, + "logprob": -2.3378906, + "text": " faut" + }, + { + "id": 3913, + "logprob": -4.3554688, + "text": " tout" + }, + { + "id": 39261, + "logprob": -2.9238281, + "text": " d'abord" + } + ], + "seed": 0, + "tokens": [ + { + "id": 408, + "logprob": -0.07891846, + 
"special": false, + "text": " que" + }, + { + "id": 366, + "logprob": -1.2939453, + "special": false, + "text": " la" + }, + { + "id": 8769, + "logprob": -0.3708496, + "special": false, + "text": " personne" + }, + { + "id": 1479, + "logprob": -2.2871094, + "special": false, + "text": " qui" + }, + { + "id": 2997, + "logprob": -0.8671875, + "special": false, + "text": " vous" + }, + { + "id": 35977, + "logprob": -1.5097656, + "special": false, + "text": " suit" + }, + { + "id": 21558, + "logprob": -0.07891846, + "special": false, + "text": " ait" + }, + { + "id": 447, + "logprob": -0.12695312, + "special": false, + "text": " un" + }, + { + "id": 78606, + "logprob": -2.21875, + "special": false, + "text": " profil" + }, + { + "id": 3899, + "logprob": -1.3535156, + "special": false, + "text": " bien" + } + ] + }, + "generated_text": "Pour déguster un ortolan, il faut tout d'abord que la personne qui vous suit ait un profil bien" +} diff --git a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_load.json b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_load.json new file mode 100644 index 00000000..0a86bef8 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_load.json @@ -0,0 +1,514 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 17934, + "logprob": null, + "text": "Pour" + }, + { + "id": 49833, + "logprob": -10.5625, + "text": " dég" + }, + { + "id": 21543, + "logprob": -0.14770508, + "text": "uster" + }, + { + "id": 447, + "logprob": -1.9287109, + "text": " un" + }, + { + "id": 46341, + "logprob": -15.4609375, + "text": " ort" + }, + { + "id": 35567, + "logprob": -7.5585938, + "text": "olan" + }, + { + "id": 15, + "logprob": -1.4003906, + "text": "," + }, + { + "id": 1669, + "logprob": -1.5673828, + "text": " il" + }, + { + "id": 11580, + "logprob": -0.94628906, + "text": " faut" + }, + { + "id": 3913, + "logprob": -3.703125, + "text": " tout" + }, + { + "id": 39261, + "logprob": -1.5732422, + "text": " d'abord" + } + ], + "seed": null, + "tokens": [ + { + "id": 578, + "logprob": -1.7646484, + "special": false, + "text": " le" + }, + { + "id": 5608, + "logprob": -2.6113281, + "special": false, + "text": " faire" + }, + { + "id": 1767, + "logprob": -1.5263672, + "special": false, + "text": " cu" + }, + { + "id": 1273, + "logprob": -0.00010049343, + "special": false, + "text": "ire" + }, + { + "id": 1486, + "logprob": -1.4707031, + "special": false, + "text": " dans" + }, + { + "id": 283, + "logprob": -1.2119141, + "special": false, + "text": " de" + }, + { + "id": 40410, + "logprob": -0.11883545, + "special": false, + "text": " l'eau" + }, + { + "id": 20226, + "logprob": -0.40844727, + "special": false, + "text": " bou" + }, + { + "id": 172483, + "logprob": -0.0037841797, + "special": false, + "text": "illante" + }, + { + "id": 2805, + "logprob": -1.0195312, + "special": false, + "text": " sal" + } + ] + }, + "generated_text": " le faire cuire dans de l'eau bouillante sal" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 17934, + "logprob": null, + "text": "Pour" + }, + { + "id": 49833, + "logprob": -10.53125, + "text": " dég" + }, + { + "id": 21543, + "logprob": -0.14770508, + "text": "uster" + }, + { + "id": 447, + "logprob": -1.9287109, + "text": " un" + }, + { + "id": 46341, + "logprob": -15.4140625, + "text": " ort" + }, + { + "id": 
35567, + "logprob": -7.5234375, + "text": "olan" + }, + { + "id": 15, + "logprob": -1.3613281, + "text": "," + }, + { + "id": 1669, + "logprob": -1.5458984, + "text": " il" + }, + { + "id": 11580, + "logprob": -0.94189453, + "text": " faut" + }, + { + "id": 3913, + "logprob": -3.7011719, + "text": " tout" + }, + { + "id": 39261, + "logprob": -1.5732422, + "text": " d'abord" + } + ], + "seed": null, + "tokens": [ + { + "id": 578, + "logprob": -1.7548828, + "special": false, + "text": " le" + }, + { + "id": 5608, + "logprob": -2.578125, + "special": false, + "text": " faire" + }, + { + "id": 1767, + "logprob": -1.5117188, + "special": false, + "text": " cu" + }, + { + "id": 1273, + "logprob": -0.00010049343, + "special": false, + "text": "ire" + }, + { + "id": 1486, + "logprob": -1.4707031, + "special": false, + "text": " dans" + }, + { + "id": 283, + "logprob": -1.1982422, + "special": false, + "text": " de" + }, + { + "id": 40410, + "logprob": -0.11004639, + "special": false, + "text": " l'eau" + }, + { + "id": 20226, + "logprob": -0.4506836, + "special": false, + "text": " bou" + }, + { + "id": 172483, + "logprob": -0.003047943, + "special": false, + "text": "illante" + }, + { + "id": 2805, + "logprob": -1.0185547, + "special": false, + "text": " sal" + } + ] + }, + "generated_text": " le faire cuire dans de l'eau bouillante sal" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 17934, + "logprob": null, + "text": "Pour" + }, + { + "id": 49833, + "logprob": -10.53125, + "text": " dég" + }, + { + "id": 21543, + "logprob": -0.14770508, + "text": "uster" + }, + { + "id": 447, + "logprob": -1.9287109, + "text": " un" + }, + { + "id": 46341, + "logprob": -15.4140625, + "text": " ort" + }, + { + "id": 35567, + "logprob": -7.5234375, + "text": "olan" + }, + { + "id": 15, + "logprob": -1.3613281, + "text": "," + }, + { + "id": 1669, + "logprob": -1.5458984, + "text": " il" + }, + { + "id": 11580, + "logprob": -0.94189453, + "text": " faut" + }, + { + "id": 3913, + "logprob": -3.7011719, + "text": " tout" + }, + { + "id": 39261, + "logprob": -1.5732422, + "text": " d'abord" + } + ], + "seed": null, + "tokens": [ + { + "id": 578, + "logprob": -1.7548828, + "special": false, + "text": " le" + }, + { + "id": 5608, + "logprob": -2.578125, + "special": false, + "text": " faire" + }, + { + "id": 1767, + "logprob": -1.5117188, + "special": false, + "text": " cu" + }, + { + "id": 1273, + "logprob": -0.00010049343, + "special": false, + "text": "ire" + }, + { + "id": 1486, + "logprob": -1.4707031, + "special": false, + "text": " dans" + }, + { + "id": 283, + "logprob": -1.1982422, + "special": false, + "text": " de" + }, + { + "id": 40410, + "logprob": -0.11004639, + "special": false, + "text": " l'eau" + }, + { + "id": 20226, + "logprob": -0.4506836, + "special": false, + "text": " bou" + }, + { + "id": 172483, + "logprob": -0.003047943, + "special": false, + "text": "illante" + }, + { + "id": 2805, + "logprob": -1.0185547, + "special": false, + "text": " sal" + } + ] + }, + "generated_text": " le faire cuire dans de l'eau bouillante sal" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 17934, + "logprob": null, + "text": "Pour" + }, + { + "id": 49833, + "logprob": -10.53125, + "text": " dég" + }, + { + "id": 21543, + "logprob": -0.14770508, + "text": "uster" + }, + { + "id": 447, + "logprob": -1.9287109, + "text": " un" + }, + { + "id": 
46341, + "logprob": -15.4140625, + "text": " ort" + }, + { + "id": 35567, + "logprob": -7.5234375, + "text": "olan" + }, + { + "id": 15, + "logprob": -1.3613281, + "text": "," + }, + { + "id": 1669, + "logprob": -1.5458984, + "text": " il" + }, + { + "id": 11580, + "logprob": -0.94189453, + "text": " faut" + }, + { + "id": 3913, + "logprob": -3.7011719, + "text": " tout" + }, + { + "id": 39261, + "logprob": -1.5732422, + "text": " d'abord" + } + ], + "seed": null, + "tokens": [ + { + "id": 578, + "logprob": -1.7548828, + "special": false, + "text": " le" + }, + { + "id": 5608, + "logprob": -2.578125, + "special": false, + "text": " faire" + }, + { + "id": 1767, + "logprob": -1.5117188, + "special": false, + "text": " cu" + }, + { + "id": 1273, + "logprob": -0.00010049343, + "special": false, + "text": "ire" + }, + { + "id": 1486, + "logprob": -1.4707031, + "special": false, + "text": " dans" + }, + { + "id": 283, + "logprob": -1.1982422, + "special": false, + "text": " de" + }, + { + "id": 40410, + "logprob": -0.11004639, + "special": false, + "text": " l'eau" + }, + { + "id": 20226, + "logprob": -0.4506836, + "special": false, + "text": " bou" + }, + { + "id": 172483, + "logprob": -0.003047943, + "special": false, + "text": "illante" + }, + { + "id": 2805, + "logprob": -1.0185547, + "special": false, + "text": " sal" + } + ] + }, + "generated_text": " le faire cuire dans de l'eau bouillante sal" + } +] diff --git a/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded.json b/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded.json new file mode 100644 index 00000000..dd8936af --- /dev/null +++ b/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded.json @@ -0,0 +1,128 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 17934, + "logprob": null, + "text": "Pour" + }, + { + "id": 49833, + "logprob": -10.5390625, + "text": " dég" + }, + { + "id": 21543, + "logprob": -0.14758301, + "text": "uster" + }, + { + "id": 447, + "logprob": -1.9296875, + "text": " un" + }, + { + "id": 46341, + "logprob": -15.4453125, + "text": " ort" + }, + { + "id": 35567, + "logprob": -7.59375, + "text": "olan" + }, + { + "id": 15, + "logprob": -1.3994141, + "text": "," + }, + { + "id": 1669, + "logprob": -1.578125, + "text": " il" + }, + { + "id": 11580, + "logprob": -0.9453125, + "text": " faut" + }, + { + "id": 3913, + "logprob": -3.7011719, + "text": " tout" + }, + { + "id": 39261, + "logprob": -1.5732422, + "text": " d'abord" + } + ], + "seed": 0, + "tokens": [ + { + "id": 578, + "logprob": -1.6474609, + "special": false, + "text": " le" + }, + { + "id": 5608, + "logprob": -2.5097656, + "special": false, + "text": " faire" + }, + { + "id": 159570, + "logprob": -6.65625, + "special": false, + "text": " réch" + }, + { + "id": 810, + "logprob": 0.0, + "special": false, + "text": "au" + }, + { + "id": 12736, + "logprob": 0.0, + "special": false, + "text": "ffer" + }, + { + "id": 1742, + "logprob": -2.5859375, + "special": false, + "text": " au" + }, + { + "id": 6105, + "logprob": -2.03125, + "special": false, + "text": " bain" + }, + { + "id": 88254, + "logprob": -0.12695312, + "special": false, + "text": "-mar" + }, + { + "id": 641, + "logprob": 0.0, + "special": false, + "text": "ie" + }, + { + "id": 2940, + "logprob": -3.5175781, + "special": false, + "text": " avec" + } + ] + }, + "generated_text": " le faire réchauffer au bain-marie 
avec" +} diff --git a/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded_load.json b/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded_load.json new file mode 100644 index 00000000..2dd480b9 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded_load.json @@ -0,0 +1,514 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 17934, + "logprob": null, + "text": "Pour" + }, + { + "id": 49833, + "logprob": -10.5390625, + "text": " dég" + }, + { + "id": 21543, + "logprob": -0.14758301, + "text": "uster" + }, + { + "id": 447, + "logprob": -1.9296875, + "text": " un" + }, + { + "id": 46341, + "logprob": -15.4453125, + "text": " ort" + }, + { + "id": 35567, + "logprob": -7.59375, + "text": "olan" + }, + { + "id": 15, + "logprob": -1.3994141, + "text": "," + }, + { + "id": 1669, + "logprob": -1.578125, + "text": " il" + }, + { + "id": 11580, + "logprob": -0.9453125, + "text": " faut" + }, + { + "id": 3913, + "logprob": -3.7011719, + "text": " tout" + }, + { + "id": 39261, + "logprob": -1.5732422, + "text": " d'abord" + } + ], + "seed": null, + "tokens": [ + { + "id": 578, + "logprob": -1.7529297, + "special": false, + "text": " le" + }, + { + "id": 5608, + "logprob": -2.6054688, + "special": false, + "text": " faire" + }, + { + "id": 1767, + "logprob": -1.5283203, + "special": false, + "text": " cu" + }, + { + "id": 1273, + "logprob": -0.00010049343, + "special": false, + "text": "ire" + }, + { + "id": 1486, + "logprob": -1.4716797, + "special": false, + "text": " dans" + }, + { + "id": 283, + "logprob": -1.1982422, + "special": false, + "text": " de" + }, + { + "id": 40410, + "logprob": -0.11853027, + "special": false, + "text": " l'eau" + }, + { + "id": 20226, + "logprob": -0.41210938, + "special": false, + "text": " bou" + }, + { + "id": 172483, + "logprob": -0.0037765503, + "special": false, + "text": "illante" + }, + { + "id": 2805, + "logprob": -1.0166016, + "special": false, + "text": " sal" + } + ] + }, + "generated_text": " le faire cuire dans de l'eau bouillante sal" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 17934, + "logprob": null, + "text": "Pour" + }, + { + "id": 49833, + "logprob": -10.515625, + "text": " dég" + }, + { + "id": 21543, + "logprob": -0.1484375, + "text": "uster" + }, + { + "id": 447, + "logprob": -1.9287109, + "text": " un" + }, + { + "id": 46341, + "logprob": -15.34375, + "text": " ort" + }, + { + "id": 35567, + "logprob": -7.515625, + "text": "olan" + }, + { + "id": 15, + "logprob": -1.4199219, + "text": "," + }, + { + "id": 1669, + "logprob": -1.5664062, + "text": " il" + }, + { + "id": 11580, + "logprob": -0.94091797, + "text": " faut" + }, + { + "id": 3913, + "logprob": -3.6660156, + "text": " tout" + }, + { + "id": 39261, + "logprob": -1.7753906, + "text": " d'abord" + } + ], + "seed": null, + "tokens": [ + { + "id": 578, + "logprob": -1.7626953, + "special": false, + "text": " le" + }, + { + "id": 5608, + "logprob": -2.5820312, + "special": false, + "text": " faire" + }, + { + "id": 1767, + "logprob": -1.5097656, + "special": false, + "text": " cu" + }, + { + "id": 1273, + "logprob": -9.393692e-05, + "special": false, + "text": "ire" + }, + { + "id": 1486, + "logprob": -1.5175781, + "special": false, + "text": " dans" + }, + { + "id": 283, + "logprob": -1.1982422, + 
"special": false, + "text": " de" + }, + { + "id": 40410, + "logprob": -0.11883545, + "special": false, + "text": " l'eau" + }, + { + "id": 20226, + "logprob": -0.4909668, + "special": false, + "text": " bou" + }, + { + "id": 172483, + "logprob": -0.003047943, + "special": false, + "text": "illante" + }, + { + "id": 2805, + "logprob": -1.0185547, + "special": false, + "text": " sal" + } + ] + }, + "generated_text": " le faire cuire dans de l'eau bouillante sal" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 17934, + "logprob": null, + "text": "Pour" + }, + { + "id": 49833, + "logprob": -10.515625, + "text": " dég" + }, + { + "id": 21543, + "logprob": -0.1484375, + "text": "uster" + }, + { + "id": 447, + "logprob": -1.9287109, + "text": " un" + }, + { + "id": 46341, + "logprob": -15.34375, + "text": " ort" + }, + { + "id": 35567, + "logprob": -7.515625, + "text": "olan" + }, + { + "id": 15, + "logprob": -1.4199219, + "text": "," + }, + { + "id": 1669, + "logprob": -1.5664062, + "text": " il" + }, + { + "id": 11580, + "logprob": -0.94091797, + "text": " faut" + }, + { + "id": 3913, + "logprob": -3.6660156, + "text": " tout" + }, + { + "id": 39261, + "logprob": -1.7753906, + "text": " d'abord" + } + ], + "seed": null, + "tokens": [ + { + "id": 578, + "logprob": -1.7626953, + "special": false, + "text": " le" + }, + { + "id": 5608, + "logprob": -2.5820312, + "special": false, + "text": " faire" + }, + { + "id": 1767, + "logprob": -1.5097656, + "special": false, + "text": " cu" + }, + { + "id": 1273, + "logprob": -9.393692e-05, + "special": false, + "text": "ire" + }, + { + "id": 1486, + "logprob": -1.5175781, + "special": false, + "text": " dans" + }, + { + "id": 283, + "logprob": -1.1982422, + "special": false, + "text": " de" + }, + { + "id": 40410, + "logprob": -0.11883545, + "special": false, + "text": " l'eau" + }, + { + "id": 20226, + "logprob": -0.4909668, + "special": false, + "text": " bou" + }, + { + "id": 172483, + "logprob": -0.003047943, + "special": false, + "text": "illante" + }, + { + "id": 2805, + "logprob": -1.0185547, + "special": false, + "text": " sal" + } + ] + }, + "generated_text": " le faire cuire dans de l'eau bouillante sal" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 17934, + "logprob": null, + "text": "Pour" + }, + { + "id": 49833, + "logprob": -10.515625, + "text": " dég" + }, + { + "id": 21543, + "logprob": -0.1484375, + "text": "uster" + }, + { + "id": 447, + "logprob": -1.9287109, + "text": " un" + }, + { + "id": 46341, + "logprob": -15.34375, + "text": " ort" + }, + { + "id": 35567, + "logprob": -7.515625, + "text": "olan" + }, + { + "id": 15, + "logprob": -1.4199219, + "text": "," + }, + { + "id": 1669, + "logprob": -1.5664062, + "text": " il" + }, + { + "id": 11580, + "logprob": -0.94091797, + "text": " faut" + }, + { + "id": 3913, + "logprob": -3.6660156, + "text": " tout" + }, + { + "id": 39261, + "logprob": -1.7753906, + "text": " d'abord" + } + ], + "seed": null, + "tokens": [ + { + "id": 578, + "logprob": -1.7626953, + "special": false, + "text": " le" + }, + { + "id": 5608, + "logprob": -2.5820312, + "special": false, + "text": " faire" + }, + { + "id": 1767, + "logprob": -1.5097656, + "special": false, + "text": " cu" + }, + { + "id": 1273, + "logprob": -9.393692e-05, + "special": false, + "text": "ire" + }, + { + "id": 1486, + "logprob": -1.5175781, + "special": false, + 
"text": " dans" + }, + { + "id": 283, + "logprob": -1.1982422, + "special": false, + "text": " de" + }, + { + "id": 40410, + "logprob": -0.11883545, + "special": false, + "text": " l'eau" + }, + { + "id": 20226, + "logprob": -0.4909668, + "special": false, + "text": " bou" + }, + { + "id": 172483, + "logprob": -0.003047943, + "special": false, + "text": "illante" + }, + { + "id": 2805, + "logprob": -1.0185547, + "special": false, + "text": " sal" + } + ] + }, + "generated_text": " le faire cuire dans de l'eau bouillante sal" + } +] diff --git a/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json b/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json new file mode 100644 index 00000000..8631c076 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1716553098, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.5-dev0-native", + "usage": { + "completion_tokens": 100, + "prompt_tokens": 62, + "total_tokens": 162 + } +} diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json new file mode 100644 index 00000000..99c33cf7 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json @@ -0,0 +1,38 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 1, + "logprobs": null, + "text": " PR for more information?" 
+ }, + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "text": "le Business Incubator is providing a workspace" + }, + { + "finish_reason": "length", + "index": 2, + "logprobs": null, + "text": " severely flawed and often has a substandard" + }, + { + "finish_reason": "length", + "index": 3, + "logprobs": null, + "text": "hd20220811-" + } + ], + "created": 1713284455, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native", + "usage": { + "completion_tokens": 36, + "prompt_tokens": 8, + "total_tokens": 44 + } +} diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json new file mode 100644 index 00000000..d87071cf --- /dev/null +++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json @@ -0,0 +1,602 @@ +[ + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "hd" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "aho" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": "2" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "2" + } + ], + 
"created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": "2" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "ima" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": "." + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "." + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": "." + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": " Sarah" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": " Yes" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": " And" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "i" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": "'" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "," + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": " what" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": 
"text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "'" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": "s" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": " Moh" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": " is" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "m" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": " Room" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "s" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": " the" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": " tired" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": ":" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "'" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": " capital" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": " of" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + 
"finish_reason": "", + "index": 0, + "logprobs": null, + "text": " She" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": " scale" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": " of" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": " being" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" + } +] diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_single_prompt.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_single_prompt.json new file mode 100644 index 00000000..5aed4935 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_single_prompt.json @@ -0,0 +1,20 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "text": " PR for flake8" + } + ], + "created": 1713284454, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native", + "usage": { + "completion_tokens": 5, + "prompt_tokens": 6, + "total_tokens": 11 + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_awq/test_flash_llama_awq.json b/integration-tests/models/__snapshots__/test_flash_awq/test_flash_llama_awq.json new file mode 100644 index 00000000..dcd37cb9 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_awq/test_flash_llama_awq.json @@ -0,0 +1,104 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -7.703125, + "text": "What" + }, + { + "id": 338, + "logprob": -1.4765625, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.390625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.8583984, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.7548828, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.9306641, + "special": false, + "text": "\n" + }, + { + "id": 5618, + "logprob": -2.4550781, + "special": false, + "text": "What" + }, + { + "id": 338, + "logprob": -0.5732422, + "special": false, + "text": " is" + }, + { + "id": 278, + "logprob": -1.5761719, + "special": false, + "text": " the" + }, + { + "id": 4328, + "logprob": -1.5888672, + "special": false, + "text": " difference" + }, + { + "id": 1546, + "logprob": -0.026504517, + "special": false, + "text": " between" + }, + { + "id": 21784, + "logprob": -1.4287109, + "special": false, + "text": " Deep" + }, + { + "id": 29257, + "logprob": -0.15856934, + "special": false, + "text": " Learning" + }, + { + "id": 322, + "logprob": -0.17456055, + "special": false, + "text": " and" + }, + { + "id": 6189, + "logprob": -0.62646484, + "special": false, + "text": " Machine" + } + ], + "top_tokens": null + }, + "generated_text": "\nWhat is the difference between Deep Learning and Machine" +} diff --git a/integration-tests/models/__snapshots__/test_flash_awq/test_flash_llama_awq_all_params.json b/integration-tests/models/__snapshots__/test_flash_awq/test_flash_llama_awq_all_params.json new file mode 100644 index 00000000..d16d34f9 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_awq/test_flash_llama_awq_all_params.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 338, + "logprob": -9.0859375, + "text": "is" + }, + { + "id": 21784, + "logprob": -10.90625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -2.65625, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -4.8085938, + "text": "?" 
+ } + ], + "seed": 0, + "tokens": [ + { + "id": 13, + "logprob": -0.19958496, + "special": false, + "text": "\n" + }, + { + "id": 4013, + "logprob": -2.203125, + "special": false, + "text": "This" + }, + { + "id": 1139, + "logprob": -0.23693848, + "special": false, + "text": " question" + }, + { + "id": 756, + "logprob": 0.0, + "special": false, + "text": " has" + }, + { + "id": 1063, + "logprob": -0.076538086, + "special": false, + "text": " been" + }, + { + "id": 4433, + "logprob": 0.0, + "special": false, + "text": " asked" + }, + { + "id": 1784, + "logprob": -1.1367188, + "special": false, + "text": " many" + }, + { + "id": 3064, + "logprob": 0.0, + "special": false, + "text": " times" + }, + { + "id": 322, + "logprob": -1.7460938, + "special": false, + "text": " and" + }, + { + "id": 306, + "logprob": 0.0, + "special": false, + "text": " I" + } + ], + "top_tokens": null + }, + "generated_text": "What is Deep Learning?\nThis question has been asked many times and I" +} diff --git a/integration-tests/models/__snapshots__/test_flash_awq/test_flash_llama_awq_load.json b/integration-tests/models/__snapshots__/test_flash_awq/test_flash_llama_awq_load.json new file mode 100644 index 00000000..e6fb3dc0 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_awq/test_flash_llama_awq_load.json @@ -0,0 +1,418 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -7.703125, + "text": "What" + }, + { + "id": 338, + "logprob": -1.4765625, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.390625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.8652344, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.7548828, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.9306641, + "special": false, + "text": "\n" + }, + { + "id": 5618, + "logprob": -2.4550781, + "special": false, + "text": "What" + }, + { + "id": 338, + "logprob": -0.5732422, + "special": false, + "text": " is" + }, + { + "id": 278, + "logprob": -1.5761719, + "special": false, + "text": " the" + }, + { + "id": 4328, + "logprob": -1.5888672, + "special": false, + "text": " difference" + }, + { + "id": 1546, + "logprob": -0.026504517, + "special": false, + "text": " between" + }, + { + "id": 21784, + "logprob": -1.4287109, + "special": false, + "text": " Deep" + }, + { + "id": 29257, + "logprob": -0.15856934, + "special": false, + "text": " Learning" + }, + { + "id": 322, + "logprob": -0.17456055, + "special": false, + "text": " and" + }, + { + "id": 6189, + "logprob": -0.62646484, + "special": false, + "text": " Machine" + } + ], + "top_tokens": null + }, + "generated_text": "\nWhat is the difference between Deep Learning and Machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -7.703125, + "text": "What" + }, + { + "id": 338, + "logprob": -1.4765625, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.390625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.8583984, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.7548828, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.9306641, + "special": false, + "text": "\n" + }, + { + "id": 5618, + "logprob": -2.4550781, + "special": false, + "text": "What" + }, + { + "id": 338, + "logprob": -0.5732422, + "special": false, + "text": " is" + }, + { + "id": 278, + "logprob": -1.5761719, + "special": false, + "text": " the" + }, + { + "id": 4328, + "logprob": -1.5888672, + "special": false, + "text": " difference" + }, + { + "id": 1546, + "logprob": -0.026504517, + "special": false, + "text": " between" + }, + { + "id": 21784, + "logprob": -1.4287109, + "special": false, + "text": " Deep" + }, + { + "id": 29257, + "logprob": -0.15856934, + "special": false, + "text": " Learning" + }, + { + "id": 322, + "logprob": -0.17456055, + "special": false, + "text": " and" + }, + { + "id": 6189, + "logprob": -0.62646484, + "special": false, + "text": " Machine" + } + ], + "top_tokens": null + }, + "generated_text": "\nWhat is the difference between Deep Learning and Machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -7.703125, + "text": "What" + }, + { + "id": 338, + "logprob": -1.4765625, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.390625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.8652344, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.7548828, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.9306641, + "special": false, + "text": "\n" + }, + { + "id": 5618, + "logprob": -2.4550781, + "special": false, + "text": "What" + }, + { + "id": 338, + "logprob": -0.5732422, + "special": false, + "text": " is" + }, + { + "id": 278, + "logprob": -1.5761719, + "special": false, + "text": " the" + }, + { + "id": 4328, + "logprob": -1.5888672, + "special": false, + "text": " difference" + }, + { + "id": 1546, + "logprob": -0.026504517, + "special": false, + "text": " between" + }, + { + "id": 21784, + "logprob": -1.4287109, + "special": false, + "text": " Deep" + }, + { + "id": 29257, + "logprob": -0.15856934, + "special": false, + "text": " Learning" + }, + { + "id": 322, + "logprob": -0.17456055, + "special": false, + "text": " and" + }, + { + "id": 6189, + "logprob": -0.62646484, + "special": false, + "text": " Machine" + } + ], + "top_tokens": null + }, + "generated_text": "\nWhat is the difference between Deep Learning and Machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -7.703125, + "text": "What" + }, + { + "id": 338, + "logprob": -1.4765625, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.390625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.8652344, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.7548828, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.9306641, + "special": false, + "text": "\n" + }, + { + "id": 5618, + "logprob": -2.4550781, + "special": false, + "text": "What" + }, + { + "id": 338, + "logprob": -0.5732422, + "special": false, + "text": " is" + }, + { + "id": 278, + "logprob": -1.5761719, + "special": false, + "text": " the" + }, + { + "id": 4328, + "logprob": -1.5888672, + "special": false, + "text": " difference" + }, + { + "id": 1546, + "logprob": -0.026504517, + "special": false, + "text": " between" + }, + { + "id": 21784, + "logprob": -1.4287109, + "special": false, + "text": " Deep" + }, + { + "id": 29257, + "logprob": -0.15856934, + "special": false, + "text": " Learning" + }, + { + "id": 322, + "logprob": -0.17456055, + "special": false, + "text": " and" + }, + { + "id": 6189, + "logprob": -0.62646484, + "special": false, + "text": " Machine" + } + ], + "top_tokens": null + }, + "generated_text": "\nWhat is the difference between Deep Learning and Machine" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_awq_sharded/test_flash_llama_awq_load_sharded.json b/integration-tests/models/__snapshots__/test_flash_awq_sharded/test_flash_llama_awq_load_sharded.json new file mode 100644 index 00000000..f1d9129d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_awq_sharded/test_flash_llama_awq_load_sharded.json @@ -0,0 +1,418 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -7.6914062, + "text": "What" + }, + { + "id": 338, + "logprob": -1.4746094, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.390625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.8623047, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.7558594, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.9228516, + "special": false, + "text": "\n" + }, + { + "id": 5618, + "logprob": -2.4609375, + "special": false, + "text": "What" + }, + { + "id": 338, + "logprob": -0.57177734, + "special": false, + "text": " is" + }, + { + "id": 278, + "logprob": -1.5722656, + "special": false, + "text": " the" + }, + { + "id": 4328, + "logprob": -1.5859375, + "special": false, + "text": " difference" + }, + { + "id": 1546, + "logprob": -0.02633667, + "special": false, + "text": " between" + }, + { + "id": 21784, + "logprob": -1.4335938, + "special": false, + "text": " Deep" + }, + { + "id": 29257, + "logprob": -0.15991211, + "special": false, + "text": " Learning" + }, + { + "id": 322, + "logprob": -0.17456055, + "special": false, + "text": " and" + }, + { + "id": 6189, + "logprob": -0.62060547, + "special": false, + "text": " Machine" + } + ], + "top_tokens": null + }, + "generated_text": "\nWhat is the difference between Deep Learning and Machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -7.6914062, + "text": "What" + }, + { + "id": 338, + "logprob": -1.4746094, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.390625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.8623047, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.7558594, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.9228516, + "special": false, + "text": "\n" + }, + { + "id": 5618, + "logprob": -2.4609375, + "special": false, + "text": "What" + }, + { + "id": 338, + "logprob": -0.57177734, + "special": false, + "text": " is" + }, + { + "id": 278, + "logprob": -1.5722656, + "special": false, + "text": " the" + }, + { + "id": 4328, + "logprob": -1.5859375, + "special": false, + "text": " difference" + }, + { + "id": 1546, + "logprob": -0.02633667, + "special": false, + "text": " between" + }, + { + "id": 21784, + "logprob": -1.4335938, + "special": false, + "text": " Deep" + }, + { + "id": 29257, + "logprob": -0.15991211, + "special": false, + "text": " Learning" + }, + { + "id": 322, + "logprob": -0.17456055, + "special": false, + "text": " and" + }, + { + "id": 6189, + "logprob": -0.62060547, + "special": false, + "text": " Machine" + } + ], + "top_tokens": null + }, + "generated_text": "\nWhat is the difference between Deep Learning and Machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -7.6914062, + "text": "What" + }, + { + "id": 338, + "logprob": -1.4746094, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.390625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.8623047, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.7558594, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.9228516, + "special": false, + "text": "\n" + }, + { + "id": 5618, + "logprob": -2.4609375, + "special": false, + "text": "What" + }, + { + "id": 338, + "logprob": -0.57177734, + "special": false, + "text": " is" + }, + { + "id": 278, + "logprob": -1.5722656, + "special": false, + "text": " the" + }, + { + "id": 4328, + "logprob": -1.5859375, + "special": false, + "text": " difference" + }, + { + "id": 1546, + "logprob": -0.02633667, + "special": false, + "text": " between" + }, + { + "id": 21784, + "logprob": -1.4335938, + "special": false, + "text": " Deep" + }, + { + "id": 29257, + "logprob": -0.15991211, + "special": false, + "text": " Learning" + }, + { + "id": 322, + "logprob": -0.17456055, + "special": false, + "text": " and" + }, + { + "id": 6189, + "logprob": -0.62060547, + "special": false, + "text": " Machine" + } + ], + "top_tokens": null + }, + "generated_text": "\nWhat is the difference between Deep Learning and Machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -7.6914062, + "text": "What" + }, + { + "id": 338, + "logprob": -1.4746094, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.390625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.8623047, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.7558594, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.9228516, + "special": false, + "text": "\n" + }, + { + "id": 5618, + "logprob": -2.4609375, + "special": false, + "text": "What" + }, + { + "id": 338, + "logprob": -0.57177734, + "special": false, + "text": " is" + }, + { + "id": 278, + "logprob": -1.5722656, + "special": false, + "text": " the" + }, + { + "id": 4328, + "logprob": -1.5859375, + "special": false, + "text": " difference" + }, + { + "id": 1546, + "logprob": -0.02633667, + "special": false, + "text": " between" + }, + { + "id": 21784, + "logprob": -1.4335938, + "special": false, + "text": " Deep" + }, + { + "id": 29257, + "logprob": -0.15991211, + "special": false, + "text": " Learning" + }, + { + "id": 322, + "logprob": -0.17456055, + "special": false, + "text": " and" + }, + { + "id": 6189, + "logprob": -0.62060547, + "special": false, + "text": " Machine" + } + ], + "top_tokens": null + }, + "generated_text": "\nWhat is the difference between Deep Learning and Machine" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_awq_sharded/test_flash_llama_awq_sharded.json b/integration-tests/models/__snapshots__/test_flash_awq_sharded/test_flash_llama_awq_sharded.json new file mode 100644 index 00000000..0f91eb36 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_awq_sharded/test_flash_llama_awq_sharded.json @@ -0,0 +1,104 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -7.6914062, + "text": "What" + }, + { + "id": 338, + "logprob": -1.4746094, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.390625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.8623047, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.7558594, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.9228516, + "special": false, + "text": "\n" + }, + { + "id": 5618, + "logprob": -2.4609375, + "special": false, + "text": "What" + }, + { + "id": 338, + "logprob": -0.57177734, + "special": false, + "text": " is" + }, + { + "id": 278, + "logprob": -1.5722656, + "special": false, + "text": " the" + }, + { + "id": 4328, + "logprob": -1.5927734, + "special": false, + "text": " difference" + }, + { + "id": 1546, + "logprob": -0.026428223, + "special": false, + "text": " between" + }, + { + "id": 21784, + "logprob": -1.4267578, + "special": false, + "text": " Deep" + }, + { + "id": 29257, + "logprob": -0.16015625, + "special": false, + "text": " Learning" + }, + { + "id": 322, + "logprob": -0.17382812, + "special": false, + "text": " and" + }, + { + "id": 6189, + "logprob": -0.62060547, + "special": false, + "text": " Machine" + } + ], + "top_tokens": null + }, + "generated_text": "\nWhat is the difference between Deep Learning and Machine" +} diff --git a/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon.json b/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon.json new file mode 100644 index 00000000..488f3de3 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon.json @@ -0,0 +1,378 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50, + "logprob": null, + "text": "G" + }, + { + "id": 330, + "logprob": -5.96875, + "text": "ir" + }, + { + "id": 1622, + "logprob": -5.6132812, + "text": "af" + }, + { + "id": 249, + "logprob": -6.5039062, + "text": "at" + }, + { + "id": 1480, + "logprob": -8.078125, + "text": "ron" + }, + { + "id": 304, + "logprob": -2.3261719, + "text": " is" + }, + { + "id": 23866, + "logprob": -9.59375, + "text": " obsessed" + }, + { + "id": 335, + "logprob": -0.048339844, + "text": " with" + }, + { + "id": 26680, + "logprob": -4.0, + "text": " gir" + }, + { + "id": 1903, + "logprob": -0.07556152, + "text": "aff" + }, + { + "id": 255, + "logprob": -0.0067749023, + "text": "es" + }, + { + "id": 23, + "logprob": -1.546875, + "text": "," + }, + { + "id": 248, + "logprob": -4.3320312, + "text": " the" + }, + { + "id": 758, + "logprob": -3.734375, + "text": " most" + }, + { + "id": 21735, + "logprob": -5.109375, + "text": " glorious" + }, + { + "id": 5985, + "logprob": -2.09375, + "text": " animal" + }, + { + "id": 313, + "logprob": -1.1835938, + "text": " on" + }, + { + "id": 248, + "logprob": -0.77685547, + "text": " the" + }, + { + "id": 1936, + "logprob": -2.3828125, + "text": " face" + }, + { + "id": 275, + "logprob": -0.004432678, + "text": " of" + }, + { + "id": 414, + "logprob": -1.9677734, + "text": " this" + }, + { + "id": 6490, + "logprob": -2.046875, + "text": " Earth" + }, + { + "id": 25, + "logprob": -0.28198242, + "text": "." 
+ }, + { + "id": 401, + "logprob": -7.9179688, + "text": " G" + }, + { + "id": 6013, + "logprob": -2.2753906, + "text": "ira" + }, + { + "id": 694, + "logprob": -0.6230469, + "text": "ft" + }, + { + "id": 1480, + "logprob": -0.20874023, + "text": "ron" + }, + { + "id": 9369, + "logprob": -4.5507812, + "text": " believes" + }, + { + "id": 455, + "logprob": -4.5664062, + "text": " all" + }, + { + "id": 599, + "logprob": -2.7402344, + "text": " other" + }, + { + "id": 5632, + "logprob": -0.21948242, + "text": " animals" + }, + { + "id": 362, + "logprob": -0.7675781, + "text": " are" + }, + { + "id": 23981, + "logprob": -5.0, + "text": " irrelevant" + }, + { + "id": 635, + "logprob": -4.234375, + "text": " when" + }, + { + "id": 4354, + "logprob": -0.5131836, + "text": " compared" + }, + { + "id": 271, + "logprob": -0.103637695, + "text": " to" + }, + { + "id": 248, + "logprob": -0.58447266, + "text": " the" + }, + { + "id": 21735, + "logprob": -3.6835938, + "text": " glorious" + }, + { + "id": 64398, + "logprob": -1.8173828, + "text": " majesty" + }, + { + "id": 275, + "logprob": -0.23510742, + "text": " of" + }, + { + "id": 248, + "logprob": -0.35473633, + "text": " the" + }, + { + "id": 26680, + "logprob": -0.24633789, + "text": " gir" + }, + { + "id": 23226, + "logprob": -0.02960205, + "text": "affe" + }, + { + "id": 25, + "logprob": -0.17333984, + "text": "." + }, + { + "id": 193, + "logprob": -1.3935547, + "text": "\n" + }, + { + "id": 23626, + "logprob": -10.0625, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -4.59375, + "text": ":" + }, + { + "id": 23090, + "logprob": -6.9375, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.99365234, + "text": "," + }, + { + "id": 29033, + "logprob": -2.2324219, + "text": " Gir" + }, + { + "id": 1622, + "logprob": -0.10809326, + "text": "af" + }, + { + "id": 249, + "logprob": -0.042663574, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0024776459, + "text": "ron" + }, + { + "id": 12, + "logprob": -1.4277344, + "text": "!" + }, + { + "id": 193, + "logprob": -1.1015625, + "text": "\n" + }, + { + "id": 50, + "logprob": -0.05709839, + "text": "G" + }, + { + "id": 330, + "logprob": -0.13208008, + "text": "ir" + }, + { + "id": 1622, + "logprob": -0.0071487427, + "text": "af" + }, + { + "id": 249, + "logprob": -0.008468628, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.00068998337, + "text": "ron" + }, + { + "id": 37, + "logprob": -0.0074691772, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 23090, + "logprob": -1.8251953, + "special": false, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.3173828, + "special": false, + "text": "," + }, + { + "id": 8156, + "logprob": -0.23803711, + "special": false, + "text": " Daniel" + }, + { + "id": 12, + "logprob": -0.56933594, + "special": false, + "text": "!" 
+ }, + { + "id": 193, + "logprob": -0.61279297, + "special": false, + "text": "\n" + }, + { + "id": 23626, + "logprob": -0.41967773, + "special": false, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -0.0023403168, + "special": false, + "text": ":" + }, + { + "id": 1634, + "logprob": -2.0605469, + "special": false, + "text": " What" + }, + { + "id": 18, + "logprob": -1.5292969, + "special": false, + "text": "'" + }, + { + "id": 94, + "logprob": -0.007904053, + "special": false, + "text": "s" + } + ] + }, + "generated_text": " Hello, Daniel!\nDaniel: What's" +} diff --git a/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_all_params.json b/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_all_params.json new file mode 100644 index 00000000..cd35186d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_all_params.json @@ -0,0 +1,98 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 330, + "logprob": null, + "text": "ir" + }, + { + "id": 1622, + "logprob": -7.8125, + "text": "af" + }, + { + "id": 249, + "logprob": -4.5, + "text": "at" + }, + { + "id": 1480, + "logprob": -10.875, + "text": "ron" + }, + { + "id": 37, + "logprob": -3.6875, + "text": ":" + } + ], + "seed": 0, + "tokens": [ + { + "id": 836, + "logprob": -1.265625, + "special": false, + "text": " i" + }, + { + "id": 18, + "logprob": -0.119628906, + "special": false, + "text": "'" + }, + { + "id": 298, + "logprob": -2.265625, + "special": false, + "text": "ve" + }, + { + "id": 650, + "logprob": -0.49804688, + "special": false, + "text": " been" + }, + { + "id": 1241, + "logprob": 0.0, + "special": false, + "text": " using" + }, + { + "id": 334, + "logprob": 0.0, + "special": false, + "text": " it" + }, + { + "id": 312, + "logprob": -1.2421875, + "special": false, + "text": " for" + }, + { + "id": 909, + "logprob": -0.99609375, + "special": false, + "text": " years" + }, + { + "id": 193, + "logprob": -0.30273438, + "special": false, + "text": "\n" + }, + { + "id": 807, + "logprob": -1.078125, + "special": false, + "text": "ik" + } + ] + }, + "generated_text": "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. 
Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron: i've been using it for years\nik" +} diff --git a/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_load.json b/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_load.json new file mode 100644 index 00000000..90a35eb7 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_load.json @@ -0,0 +1,1514 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50, + "logprob": null, + "text": "G" + }, + { + "id": 330, + "logprob": -5.96875, + "text": "ir" + }, + { + "id": 1622, + "logprob": -5.6171875, + "text": "af" + }, + { + "id": 249, + "logprob": -6.5039062, + "text": "at" + }, + { + "id": 1480, + "logprob": -8.0703125, + "text": "ron" + }, + { + "id": 304, + "logprob": -2.328125, + "text": " is" + }, + { + "id": 23866, + "logprob": -9.59375, + "text": " obsessed" + }, + { + "id": 335, + "logprob": -0.04837036, + "text": " with" + }, + { + "id": 26680, + "logprob": -3.9960938, + "text": " gir" + }, + { + "id": 1903, + "logprob": -0.07525635, + "text": "aff" + }, + { + "id": 255, + "logprob": -0.006790161, + "text": "es" + }, + { + "id": 23, + "logprob": -1.546875, + "text": "," + }, + { + "id": 248, + "logprob": -4.3320312, + "text": " the" + }, + { + "id": 758, + "logprob": -3.7363281, + "text": " most" + }, + { + "id": 21735, + "logprob": -5.109375, + "text": " glorious" + }, + { + "id": 5985, + "logprob": -2.09375, + "text": " animal" + }, + { + "id": 313, + "logprob": -1.1845703, + "text": " on" + }, + { + "id": 248, + "logprob": -0.77734375, + "text": " the" + }, + { + "id": 1936, + "logprob": -2.3828125, + "text": " face" + }, + { + "id": 275, + "logprob": -0.0044403076, + "text": " of" + }, + { + "id": 414, + "logprob": -1.9667969, + "text": " this" + }, + { + "id": 6490, + "logprob": -2.0449219, + "text": " Earth" + }, + { + "id": 25, + "logprob": -0.28198242, + "text": "." + }, + { + "id": 401, + "logprob": -7.921875, + "text": " G" + }, + { + "id": 6013, + "logprob": -2.2714844, + "text": "ira" + }, + { + "id": 694, + "logprob": -0.62353516, + "text": "ft" + }, + { + "id": 1480, + "logprob": -0.20947266, + "text": "ron" + }, + { + "id": 9369, + "logprob": -4.5507812, + "text": " believes" + }, + { + "id": 455, + "logprob": -4.5625, + "text": " all" + }, + { + "id": 599, + "logprob": -2.7402344, + "text": " other" + }, + { + "id": 5632, + "logprob": -0.21899414, + "text": " animals" + }, + { + "id": 362, + "logprob": -0.76708984, + "text": " are" + }, + { + "id": 23981, + "logprob": -4.9960938, + "text": " irrelevant" + }, + { + "id": 635, + "logprob": -4.234375, + "text": " when" + }, + { + "id": 4354, + "logprob": -0.5131836, + "text": " compared" + }, + { + "id": 271, + "logprob": -0.103515625, + "text": " to" + }, + { + "id": 248, + "logprob": -0.58447266, + "text": " the" + }, + { + "id": 21735, + "logprob": -3.6796875, + "text": " glorious" + }, + { + "id": 64398, + "logprob": -1.8222656, + "text": " majesty" + }, + { + "id": 275, + "logprob": -0.23583984, + "text": " of" + }, + { + "id": 248, + "logprob": -0.3544922, + "text": " the" + }, + { + "id": 26680, + "logprob": -0.24609375, + "text": " gir" + }, + { + "id": 23226, + "logprob": -0.02960205, + "text": "affe" + }, + { + "id": 25, + "logprob": -0.17358398, + "text": "." 
+ }, + { + "id": 193, + "logprob": -1.3925781, + "text": "\n" + }, + { + "id": 23626, + "logprob": -10.0625, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -4.5898438, + "text": ":" + }, + { + "id": 23090, + "logprob": -6.9375, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.99365234, + "text": "," + }, + { + "id": 29033, + "logprob": -2.2304688, + "text": " Gir" + }, + { + "id": 1622, + "logprob": -0.107788086, + "text": "af" + }, + { + "id": 249, + "logprob": -0.04257202, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0024871826, + "text": "ron" + }, + { + "id": 12, + "logprob": -1.4277344, + "text": "!" + }, + { + "id": 193, + "logprob": -1.1005859, + "text": "\n" + }, + { + "id": 50, + "logprob": -0.056915283, + "text": "G" + }, + { + "id": 330, + "logprob": -0.1315918, + "text": "ir" + }, + { + "id": 1622, + "logprob": -0.0071105957, + "text": "af" + }, + { + "id": 249, + "logprob": -0.008453369, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0006928444, + "text": "ron" + }, + { + "id": 37, + "logprob": -0.0074920654, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 23090, + "logprob": -1.828125, + "special": false, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.3178711, + "special": false, + "text": "," + }, + { + "id": 8156, + "logprob": -0.23925781, + "special": false, + "text": " Daniel" + }, + { + "id": 12, + "logprob": -0.5698242, + "special": false, + "text": "!" + }, + { + "id": 193, + "logprob": -0.61279297, + "special": false, + "text": "\n" + }, + { + "id": 23626, + "logprob": -0.4177246, + "special": false, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -0.0023345947, + "special": false, + "text": ":" + }, + { + "id": 1634, + "logprob": -2.0605469, + "special": false, + "text": " What" + }, + { + "id": 18, + "logprob": -1.5283203, + "special": false, + "text": "'" + }, + { + "id": 94, + "logprob": -0.007965088, + "special": false, + "text": "s" + } + ] + }, + "generated_text": " Hello, Daniel!\nDaniel: What's" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50, + "logprob": null, + "text": "G" + }, + { + "id": 330, + "logprob": -5.96875, + "text": "ir" + }, + { + "id": 1622, + "logprob": -5.6171875, + "text": "af" + }, + { + "id": 249, + "logprob": -6.5, + "text": "at" + }, + { + "id": 1480, + "logprob": -8.0703125, + "text": "ron" + }, + { + "id": 304, + "logprob": -2.328125, + "text": " is" + }, + { + "id": 23866, + "logprob": -9.59375, + "text": " obsessed" + }, + { + "id": 335, + "logprob": -0.048339844, + "text": " with" + }, + { + "id": 26680, + "logprob": -4.0, + "text": " gir" + }, + { + "id": 1903, + "logprob": -0.07531738, + "text": "aff" + }, + { + "id": 255, + "logprob": -0.006793976, + "text": "es" + }, + { + "id": 23, + "logprob": -1.5478516, + "text": "," + }, + { + "id": 248, + "logprob": -4.3320312, + "text": " the" + }, + { + "id": 758, + "logprob": -3.7363281, + "text": " most" + }, + { + "id": 21735, + "logprob": -5.1132812, + "text": " glorious" + }, + { + "id": 5985, + "logprob": -2.0957031, + "text": " animal" + }, + { + "id": 313, + "logprob": -1.1835938, + "text": " on" + }, + { + "id": 248, + "logprob": -0.77685547, + "text": " the" + }, + { + "id": 1936, + "logprob": -2.3808594, + "text": " face" + }, + { + "id": 275, + "logprob": -0.004436493, + "text": " of" + }, + { + "id": 414, + "logprob": -1.9638672, + "text": " this" + }, + { + "id": 6490, + "logprob": -2.0449219, + "text": " Earth" + }, + { + 
"id": 25, + "logprob": -0.28198242, + "text": "." + }, + { + "id": 401, + "logprob": -7.9179688, + "text": " G" + }, + { + "id": 6013, + "logprob": -2.2734375, + "text": "ira" + }, + { + "id": 694, + "logprob": -0.6230469, + "text": "ft" + }, + { + "id": 1480, + "logprob": -0.20947266, + "text": "ron" + }, + { + "id": 9369, + "logprob": -4.5546875, + "text": " believes" + }, + { + "id": 455, + "logprob": -4.5703125, + "text": " all" + }, + { + "id": 599, + "logprob": -2.7382812, + "text": " other" + }, + { + "id": 5632, + "logprob": -0.21948242, + "text": " animals" + }, + { + "id": 362, + "logprob": -0.7661133, + "text": " are" + }, + { + "id": 23981, + "logprob": -4.9960938, + "text": " irrelevant" + }, + { + "id": 635, + "logprob": -4.234375, + "text": " when" + }, + { + "id": 4354, + "logprob": -0.5131836, + "text": " compared" + }, + { + "id": 271, + "logprob": -0.10357666, + "text": " to" + }, + { + "id": 248, + "logprob": -0.58447266, + "text": " the" + }, + { + "id": 21735, + "logprob": -3.6816406, + "text": " glorious" + }, + { + "id": 64398, + "logprob": -1.8203125, + "text": " majesty" + }, + { + "id": 275, + "logprob": -0.23583984, + "text": " of" + }, + { + "id": 248, + "logprob": -0.35473633, + "text": " the" + }, + { + "id": 26680, + "logprob": -0.24572754, + "text": " gir" + }, + { + "id": 23226, + "logprob": -0.029586792, + "text": "affe" + }, + { + "id": 25, + "logprob": -0.17346191, + "text": "." + }, + { + "id": 193, + "logprob": -1.3945312, + "text": "\n" + }, + { + "id": 23626, + "logprob": -10.0625, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -4.59375, + "text": ":" + }, + { + "id": 23090, + "logprob": -6.9375, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.99316406, + "text": "," + }, + { + "id": 29033, + "logprob": -2.2324219, + "text": " Gir" + }, + { + "id": 1622, + "logprob": -0.10797119, + "text": "af" + }, + { + "id": 249, + "logprob": -0.04248047, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0024814606, + "text": "ron" + }, + { + "id": 12, + "logprob": -1.4277344, + "text": "!" + }, + { + "id": 193, + "logprob": -1.1005859, + "text": "\n" + }, + { + "id": 50, + "logprob": -0.056884766, + "text": "G" + }, + { + "id": 330, + "logprob": -0.1315918, + "text": "ir" + }, + { + "id": 1622, + "logprob": -0.007095337, + "text": "af" + }, + { + "id": 249, + "logprob": -0.00844574, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.00068998337, + "text": "ron" + }, + { + "id": 37, + "logprob": -0.0074768066, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 23090, + "logprob": -1.8251953, + "special": false, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.31762695, + "special": false, + "text": "," + }, + { + "id": 8156, + "logprob": -0.2388916, + "special": false, + "text": " Daniel" + }, + { + "id": 12, + "logprob": -0.5698242, + "special": false, + "text": "!" 
+ }, + { + "id": 193, + "logprob": -0.6152344, + "special": false, + "text": "\n" + }, + { + "id": 23626, + "logprob": -0.42211914, + "special": false, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -0.002336502, + "special": false, + "text": ":" + }, + { + "id": 1634, + "logprob": -2.0605469, + "special": false, + "text": " What" + }, + { + "id": 18, + "logprob": -1.5292969, + "special": false, + "text": "'" + }, + { + "id": 94, + "logprob": -0.007926941, + "special": false, + "text": "s" + } + ] + }, + "generated_text": " Hello, Daniel!\nDaniel: What's" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50, + "logprob": null, + "text": "G" + }, + { + "id": 330, + "logprob": -5.96875, + "text": "ir" + }, + { + "id": 1622, + "logprob": -5.6171875, + "text": "af" + }, + { + "id": 249, + "logprob": -6.5, + "text": "at" + }, + { + "id": 1480, + "logprob": -8.0703125, + "text": "ron" + }, + { + "id": 304, + "logprob": -2.328125, + "text": " is" + }, + { + "id": 23866, + "logprob": -9.59375, + "text": " obsessed" + }, + { + "id": 335, + "logprob": -0.048339844, + "text": " with" + }, + { + "id": 26680, + "logprob": -4.0, + "text": " gir" + }, + { + "id": 1903, + "logprob": -0.07531738, + "text": "aff" + }, + { + "id": 255, + "logprob": -0.006793976, + "text": "es" + }, + { + "id": 23, + "logprob": -1.5478516, + "text": "," + }, + { + "id": 248, + "logprob": -4.3320312, + "text": " the" + }, + { + "id": 758, + "logprob": -3.7363281, + "text": " most" + }, + { + "id": 21735, + "logprob": -5.1132812, + "text": " glorious" + }, + { + "id": 5985, + "logprob": -2.0957031, + "text": " animal" + }, + { + "id": 313, + "logprob": -1.1835938, + "text": " on" + }, + { + "id": 248, + "logprob": -0.77685547, + "text": " the" + }, + { + "id": 1936, + "logprob": -2.3808594, + "text": " face" + }, + { + "id": 275, + "logprob": -0.004436493, + "text": " of" + }, + { + "id": 414, + "logprob": -1.9638672, + "text": " this" + }, + { + "id": 6490, + "logprob": -2.0449219, + "text": " Earth" + }, + { + "id": 25, + "logprob": -0.28198242, + "text": "." + }, + { + "id": 401, + "logprob": -7.9179688, + "text": " G" + }, + { + "id": 6013, + "logprob": -2.2734375, + "text": "ira" + }, + { + "id": 694, + "logprob": -0.6230469, + "text": "ft" + }, + { + "id": 1480, + "logprob": -0.20947266, + "text": "ron" + }, + { + "id": 9369, + "logprob": -4.5546875, + "text": " believes" + }, + { + "id": 455, + "logprob": -4.5703125, + "text": " all" + }, + { + "id": 599, + "logprob": -2.7382812, + "text": " other" + }, + { + "id": 5632, + "logprob": -0.21948242, + "text": " animals" + }, + { + "id": 362, + "logprob": -0.7661133, + "text": " are" + }, + { + "id": 23981, + "logprob": -4.9960938, + "text": " irrelevant" + }, + { + "id": 635, + "logprob": -4.234375, + "text": " when" + }, + { + "id": 4354, + "logprob": -0.5131836, + "text": " compared" + }, + { + "id": 271, + "logprob": -0.10357666, + "text": " to" + }, + { + "id": 248, + "logprob": -0.58447266, + "text": " the" + }, + { + "id": 21735, + "logprob": -3.6816406, + "text": " glorious" + }, + { + "id": 64398, + "logprob": -1.8203125, + "text": " majesty" + }, + { + "id": 275, + "logprob": -0.23583984, + "text": " of" + }, + { + "id": 248, + "logprob": -0.35473633, + "text": " the" + }, + { + "id": 26680, + "logprob": -0.24572754, + "text": " gir" + }, + { + "id": 23226, + "logprob": -0.029586792, + "text": "affe" + }, + { + "id": 25, + "logprob": -0.17346191, + "text": "." 
+ }, + { + "id": 193, + "logprob": -1.3945312, + "text": "\n" + }, + { + "id": 23626, + "logprob": -10.0625, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -4.59375, + "text": ":" + }, + { + "id": 23090, + "logprob": -6.9375, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.99316406, + "text": "," + }, + { + "id": 29033, + "logprob": -2.2324219, + "text": " Gir" + }, + { + "id": 1622, + "logprob": -0.10797119, + "text": "af" + }, + { + "id": 249, + "logprob": -0.04248047, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0024814606, + "text": "ron" + }, + { + "id": 12, + "logprob": -1.4277344, + "text": "!" + }, + { + "id": 193, + "logprob": -1.1005859, + "text": "\n" + }, + { + "id": 50, + "logprob": -0.056884766, + "text": "G" + }, + { + "id": 330, + "logprob": -0.1315918, + "text": "ir" + }, + { + "id": 1622, + "logprob": -0.007095337, + "text": "af" + }, + { + "id": 249, + "logprob": -0.00844574, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.00068998337, + "text": "ron" + }, + { + "id": 37, + "logprob": -0.0074768066, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 23090, + "logprob": -1.8251953, + "special": false, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.31762695, + "special": false, + "text": "," + }, + { + "id": 8156, + "logprob": -0.2388916, + "special": false, + "text": " Daniel" + }, + { + "id": 12, + "logprob": -0.5698242, + "special": false, + "text": "!" + }, + { + "id": 193, + "logprob": -0.6152344, + "special": false, + "text": "\n" + }, + { + "id": 23626, + "logprob": -0.42211914, + "special": false, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -0.002336502, + "special": false, + "text": ":" + }, + { + "id": 1634, + "logprob": -2.0605469, + "special": false, + "text": " What" + }, + { + "id": 18, + "logprob": -1.5292969, + "special": false, + "text": "'" + }, + { + "id": 94, + "logprob": -0.007926941, + "special": false, + "text": "s" + } + ] + }, + "generated_text": " Hello, Daniel!\nDaniel: What's" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50, + "logprob": null, + "text": "G" + }, + { + "id": 330, + "logprob": -5.96875, + "text": "ir" + }, + { + "id": 1622, + "logprob": -5.6171875, + "text": "af" + }, + { + "id": 249, + "logprob": -6.5, + "text": "at" + }, + { + "id": 1480, + "logprob": -8.0703125, + "text": "ron" + }, + { + "id": 304, + "logprob": -2.328125, + "text": " is" + }, + { + "id": 23866, + "logprob": -9.59375, + "text": " obsessed" + }, + { + "id": 335, + "logprob": -0.048339844, + "text": " with" + }, + { + "id": 26680, + "logprob": -4.0, + "text": " gir" + }, + { + "id": 1903, + "logprob": -0.07531738, + "text": "aff" + }, + { + "id": 255, + "logprob": -0.006793976, + "text": "es" + }, + { + "id": 23, + "logprob": -1.5478516, + "text": "," + }, + { + "id": 248, + "logprob": -4.3320312, + "text": " the" + }, + { + "id": 758, + "logprob": -3.7363281, + "text": " most" + }, + { + "id": 21735, + "logprob": -5.1132812, + "text": " glorious" + }, + { + "id": 5985, + "logprob": -2.0957031, + "text": " animal" + }, + { + "id": 313, + "logprob": -1.1835938, + "text": " on" + }, + { + "id": 248, + "logprob": -0.77685547, + "text": " the" + }, + { + "id": 1936, + "logprob": -2.3808594, + "text": " face" + }, + { + "id": 275, + "logprob": -0.004436493, + "text": " of" + }, + { + "id": 414, + "logprob": -1.9638672, + "text": " this" + }, + { + "id": 6490, + "logprob": -2.0449219, + "text": " Earth" + }, + { + "id": 
25, + "logprob": -0.28198242, + "text": "." + }, + { + "id": 401, + "logprob": -7.9179688, + "text": " G" + }, + { + "id": 6013, + "logprob": -2.2734375, + "text": "ira" + }, + { + "id": 694, + "logprob": -0.6230469, + "text": "ft" + }, + { + "id": 1480, + "logprob": -0.20947266, + "text": "ron" + }, + { + "id": 9369, + "logprob": -4.5546875, + "text": " believes" + }, + { + "id": 455, + "logprob": -4.5703125, + "text": " all" + }, + { + "id": 599, + "logprob": -2.7382812, + "text": " other" + }, + { + "id": 5632, + "logprob": -0.21948242, + "text": " animals" + }, + { + "id": 362, + "logprob": -0.7661133, + "text": " are" + }, + { + "id": 23981, + "logprob": -4.9960938, + "text": " irrelevant" + }, + { + "id": 635, + "logprob": -4.234375, + "text": " when" + }, + { + "id": 4354, + "logprob": -0.5131836, + "text": " compared" + }, + { + "id": 271, + "logprob": -0.10357666, + "text": " to" + }, + { + "id": 248, + "logprob": -0.58447266, + "text": " the" + }, + { + "id": 21735, + "logprob": -3.6816406, + "text": " glorious" + }, + { + "id": 64398, + "logprob": -1.8203125, + "text": " majesty" + }, + { + "id": 275, + "logprob": -0.23583984, + "text": " of" + }, + { + "id": 248, + "logprob": -0.35473633, + "text": " the" + }, + { + "id": 26680, + "logprob": -0.24572754, + "text": " gir" + }, + { + "id": 23226, + "logprob": -0.029586792, + "text": "affe" + }, + { + "id": 25, + "logprob": -0.17346191, + "text": "." + }, + { + "id": 193, + "logprob": -1.3945312, + "text": "\n" + }, + { + "id": 23626, + "logprob": -10.0625, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -4.59375, + "text": ":" + }, + { + "id": 23090, + "logprob": -6.9375, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.99316406, + "text": "," + }, + { + "id": 29033, + "logprob": -2.2324219, + "text": " Gir" + }, + { + "id": 1622, + "logprob": -0.10797119, + "text": "af" + }, + { + "id": 249, + "logprob": -0.04248047, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.0024814606, + "text": "ron" + }, + { + "id": 12, + "logprob": -1.4277344, + "text": "!" + }, + { + "id": 193, + "logprob": -1.1005859, + "text": "\n" + }, + { + "id": 50, + "logprob": -0.056884766, + "text": "G" + }, + { + "id": 330, + "logprob": -0.1315918, + "text": "ir" + }, + { + "id": 1622, + "logprob": -0.007095337, + "text": "af" + }, + { + "id": 249, + "logprob": -0.00844574, + "text": "at" + }, + { + "id": 1480, + "logprob": -0.00068998337, + "text": "ron" + }, + { + "id": 37, + "logprob": -0.0074768066, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 23090, + "logprob": -1.8251953, + "special": false, + "text": " Hello" + }, + { + "id": 23, + "logprob": -0.31762695, + "special": false, + "text": "," + }, + { + "id": 8156, + "logprob": -0.2388916, + "special": false, + "text": " Daniel" + }, + { + "id": 12, + "logprob": -0.5698242, + "special": false, + "text": "!" 
+ }, + { + "id": 193, + "logprob": -0.6152344, + "special": false, + "text": "\n" + }, + { + "id": 23626, + "logprob": -0.42211914, + "special": false, + "text": "Daniel" + }, + { + "id": 37, + "logprob": -0.002336502, + "special": false, + "text": ":" + }, + { + "id": 1634, + "logprob": -2.0605469, + "special": false, + "text": " What" + }, + { + "id": 18, + "logprob": -1.5292969, + "special": false, + "text": "'" + }, + { + "id": 94, + "logprob": -0.007926941, + "special": false, + "text": "s" + } + ] + }, + "generated_text": " Hello, Daniel!\nDaniel: What's" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json new file mode 100644 index 00000000..80f0d053 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.8671875, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.4375, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8203125, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23242188, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.08544922, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.9375, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.671875, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.40429688, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.1875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json new file mode 100644 index 00000000..8253dc96 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 7539, + "logprob": -0.73046875, + "special": false, + "text": " forms" + }, + { + "id": 708, + "logprob": 0.0, + "special": false, + "text": " are" + }, + { + "id": 671, + "logprob": -1.703125, + "special": false, + "text": " an" + }, + { + "id": 8727, + "logprob": 0.0, + "special": false, + "text": " essential" + }, + { + "id": 1702, + "logprob": 0.0, + "special": false, + "text": " part" + }, + { + "id": 576, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 573, + "logprob": 0.0, + "special": false, + "text": " the" + }, + { + "id": 11859, + "logprob": -1.6953125, + "special": false, + "text": " lab" + }, + { + "id": 2185, + 
"logprob": -1.3125, + "special": false, + "text": " process" + }, + { + "id": 578, + "logprob": -1.5, + "special": false, + "text": " and" + } + ], + "top_tokens": null + }, + "generated_text": "Test request forms are an essential part of the lab process and" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json new file mode 100644 index 00000000..e69ee25d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + 
"logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq.json new file mode 100644 index 00000000..760ebf94 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.640625, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.34375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4296875, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4453125, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.8632812, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1328125, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.76660156, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3837891, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -1.9746094, + "special": false, + "text": " " + }, + { + "id": 199, + 
"logprob": -1.4189453, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.34375, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.8852539, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json new file mode 100644 index 00000000..7a168b2e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.65625, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.3671875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 604, + "logprob": -0.36938477, + "special": false, + "text": " for" + }, + { + "id": 235248, + "logprob": -1.8046875, + "special": false, + "text": " " + }, + { + "id": 235274, + "logprob": -0.46240234, + "special": false, + "text": "1" + }, + { + "id": 235284, + "logprob": -1.7460938, + "special": false, + "text": "2" + }, + { + "id": 235265, + "logprob": -1.9443359, + "special": false, + "text": "." + }, + { + "id": 235284, + "logprob": -1.4550781, + "special": false, + "text": "2" + }, + { + "id": 235308, + "logprob": -1.0205078, + "special": false, + "text": "5" + }, + { + "id": 235290, + "logprob": -1.0283203, + "special": false, + "text": "-" + }, + { + "id": 235274, + "logprob": -1.2783203, + "special": false, + "text": "1" + }, + { + "id": 235284, + "logprob": 0.0, + "special": false, + "text": "2" + } + ], + "top_tokens": null + }, + "generated_text": "Test request for 12.25-12" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_load.json new file mode 100644 index 00000000..bcb9b378 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.6484375, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.359375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4277344, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4394531, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.8613281, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1523438, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.76220703, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3642578, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -2.0175781, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4238281, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.328125, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.8881836, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + 
"generated_text": " for the following:\n\n* Name" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.6484375, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.34375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4238281, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4453125, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.859375, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1445312, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.7631836, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3642578, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -1.9960938, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4179688, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.3359375, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.8847656, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.640625, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.3671875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4257812, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4453125, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.8789062, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1367188, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.76171875, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3515625, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -1.9873047, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4169922, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.3320312, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.8930664, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.6484375, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.359375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4179688, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4492188, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.8574219, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1445312, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.7519531, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3623047, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -1.9707031, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": 
-1.4267578, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.3359375, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.88427734, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2.json b/integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2.json new file mode 100644 index 00000000..ca7393a3 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2061, + "logprob": null, + "text": "What" + }, + { + "id": 318, + "logprob": -3.1835938, + "text": " is" + }, + { + "id": 2769, + "logprob": -9.171875, + "text": " deep" + }, + { + "id": 4673, + "logprob": -1.6425781, + "text": " learning" + }, + { + "id": 30, + "logprob": -0.7314453, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -0.68603516, + "special": false, + "text": "\n" + }, + { + "id": 198, + "logprob": -0.005393982, + "special": false, + "text": "\n" + }, + { + "id": 29744, + "logprob": -0.31079102, + "special": false, + "text": "Deep" + }, + { + "id": 4673, + "logprob": -0.08300781, + "special": false, + "text": " learning" + }, + { + "id": 318, + "logprob": -0.58984375, + "special": false, + "text": " is" + }, + { + "id": 257, + "logprob": -0.953125, + "special": false, + "text": " a" + }, + { + "id": 649, + "logprob": -2.0957031, + "special": false, + "text": " new" + }, + { + "id": 2214, + "logprob": -1.8095703, + "special": false, + "text": " field" + }, + { + "id": 286, + "logprob": -1.0673828, + "special": false, + "text": " of" + }, + { + "id": 2267, + "logprob": -0.9375, + "special": false, + "text": " research" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new field of research" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2_load.json b/integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2_load.json new file mode 100644 index 00000000..7bd15b90 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2_load.json @@ -0,0 +1,398 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2061, + "logprob": null, + "text": "What" + }, + { + "id": 318, + "logprob": -3.1835938, + "text": " is" + }, + { + "id": 2769, + "logprob": -9.171875, + "text": " deep" + }, + { + "id": 4673, + "logprob": -1.6425781, + "text": " learning" + }, + { + "id": 30, + "logprob": -0.7314453, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -0.68603516, + "special": false, + "text": "\n" + }, + { + "id": 198, + "logprob": -0.005672455, + "special": false, + "text": "\n" + }, + { + "id": 29744, + "logprob": -0.3251953, + "special": false, + "text": "Deep" + }, + { + "id": 4673, + "logprob": -0.08294678, + "special": false, + "text": " learning" + }, + { + "id": 318, + "logprob": -0.5854492, + "special": false, + "text": " is" + }, + { + "id": 257, + "logprob": -0.9423828, + "special": false, + "text": " a" + }, + { + "id": 649, + "logprob": -2.0800781, + "special": false, + "text": " new" + }, + { + "id": 2214, + "logprob": -1.8369141, + "special": false, + "text": " field" + }, + { + "id": 286, + "logprob": -1.0683594, + "special": false, + "text": " of" + }, + { + "id": 2267, + "logprob": -0.9711914, + "special": false, + "text": " research" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new field of research" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2061, + "logprob": null, + "text": "What" + }, + { + "id": 318, + "logprob": -3.1660156, + "text": " is" + }, + { + "id": 2769, + "logprob": -9.1796875, + "text": " deep" + }, + { + "id": 4673, + "logprob": -1.6376953, + "text": " learning" + }, + { + "id": 30, + "logprob": -0.72216797, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -0.7089844, + "special": false, + "text": "\n" + }, + { + "id": 198, + "logprob": -0.0054779053, + "special": false, + "text": "\n" + }, + { + "id": 29744, + "logprob": -0.3190918, + "special": false, + "text": "Deep" + }, + { + "id": 4673, + "logprob": -0.08319092, + "special": false, + "text": " learning" + }, + { + "id": 318, + "logprob": -0.5839844, + "special": false, + "text": " is" + }, + { + "id": 257, + "logprob": -0.9506836, + "special": false, + "text": " a" + }, + { + "id": 649, + "logprob": -2.0878906, + "special": false, + "text": " new" + }, + { + "id": 2214, + "logprob": -1.8496094, + "special": false, + "text": " field" + }, + { + "id": 286, + "logprob": -1.0673828, + "special": false, + "text": " of" + }, + { + "id": 2267, + "logprob": -0.9370117, + "special": false, + "text": " research" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new field of research" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2061, + "logprob": null, + "text": "What" + }, + { + "id": 318, + "logprob": -3.1660156, + "text": " is" + }, + { + "id": 2769, + "logprob": -9.1796875, + "text": " deep" + }, + { + "id": 4673, + "logprob": -1.6376953, + "text": " learning" + }, + { + "id": 30, + "logprob": -0.72216797, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -0.7089844, + "special": false, + "text": "\n" + }, + { + "id": 198, + "logprob": -0.0054779053, + "special": false, + "text": "\n" + }, + { + "id": 29744, + "logprob": -0.3190918, + "special": false, + "text": "Deep" + }, + { + "id": 4673, + "logprob": -0.08319092, + "special": false, + "text": " learning" + }, + { + "id": 318, + "logprob": -0.5839844, + "special": false, + "text": " is" + }, + { + "id": 257, + "logprob": -0.9506836, + "special": false, + "text": " a" + }, + { + "id": 649, + "logprob": -2.0878906, + "special": false, + "text": " new" + }, + { + "id": 2214, + "logprob": -1.8496094, + "special": false, + "text": " field" + }, + { + "id": 286, + "logprob": -1.0673828, + "special": false, + "text": " of" + }, + { + "id": 2267, + "logprob": -0.9370117, + "special": false, + "text": " research" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new field of research" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2061, + "logprob": null, + "text": "What" + }, + { + "id": 318, + "logprob": -3.1660156, + "text": " is" + }, + { + "id": 2769, + "logprob": -9.1796875, + "text": " deep" + }, + { + "id": 4673, + "logprob": -1.6376953, + "text": " learning" + }, + { + "id": 30, + "logprob": -0.72216797, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -0.7089844, + "special": false, + "text": "\n" + }, + { + "id": 198, + "logprob": -0.0054779053, + "special": false, + "text": "\n" + }, + { + "id": 29744, + "logprob": -0.3190918, + "special": false, + "text": "Deep" + }, + { + "id": 4673, + "logprob": -0.08319092, + "special": false, + "text": " learning" + }, + { + "id": 318, + "logprob": -0.5839844, + "special": false, + "text": " is" + }, + { + "id": 257, + "logprob": -0.9506836, + "special": false, + "text": " a" + }, + { + "id": 649, + "logprob": -2.0878906, + "special": false, + "text": " new" + }, + { + "id": 2214, + "logprob": -1.8496094, + "special": false, + "text": " field" + }, + { + "id": 286, + "logprob": -1.0673828, + "special": false, + "text": " of" + }, + { + "id": 2267, + "logprob": -0.9370117, + "special": false, + "text": " research" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new field of research" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar.json b/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar.json new file mode 100644 index 00000000..0e87f59e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -13.90625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -12.328125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.0566406, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.5253906, + "special": false, + "text": "\n" + }, + { + "id": 29902, + "logprob": -2.7578125, + "special": false, + "text": "I" + }, + { + "id": 4966, + "logprob": -1.9033203, + "special": false, + "text": " hope" + }, + { + "id": 445, + "logprob": -0.5019531, + "special": false, + "text": " this" + }, + 
{ + "id": 6911, + "logprob": -0.21264648, + "special": false, + "text": " helps" + }, + { + "id": 29991, + "logprob": -0.5991211, + "special": false, + "text": "!" + }, + { + "id": 2803, + "logprob": -0.37475586, + "special": false, + "text": " Let" + }, + { + "id": 592, + "logprob": -0.018463135, + "special": false, + "text": " me" + }, + { + "id": 1073, + "logprob": -0.0008597374, + "special": false, + "text": " know" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nI hope this helps! Let me know" +} diff --git a/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_json.json b/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_json.json new file mode 100644 index 00000000..d7fb620d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_json.json @@ -0,0 +1,274 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 30, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 5235, + "logprob": -10.0625, + "text": "info" + }, + { + "id": 29901, + "logprob": -3.2324219, + "text": ":" + }, + { + "id": 13260, + "logprob": -10.625, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.08276367, + "text": "id" + }, + { + "id": 8753, + "logprob": -7.5273438, + "text": "hol" + }, + { + "id": 17559, + "logprob": -3.8476562, + "text": "tz" + }, + { + "id": 763, + "logprob": -10.140625, + "text": "like" + }, + { + "id": 10697, + "logprob": -10.1953125, + "text": "trees" + }, + { + "id": 322, + "logprob": -2.5742188, + "text": "and" + }, + { + "id": 756, + "logprob": -7.4882812, + "text": "has" + }, + { + "id": 1023, + "logprob": -5.0507812, + "text": "two" + }, + { + "id": 274, + "logprob": -5.3164062, + "text": "c" + }, + { + "id": 1446, + "logprob": -0.6694336, + "text": "ats" + }, + { + "id": 29889, + "logprob": -0.9995117, + "text": "." 
+ }, + { + "id": 29871, + "logprob": -4.2421875, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 6377, + "logprob": -0.14916992, + "special": false, + "text": "{\"" + }, + { + "id": 29888, + "logprob": -0.13598633, + "special": false, + "text": "f" + }, + { + "id": 12935, + "logprob": -0.017669678, + "special": false, + "text": "irs" + }, + { + "id": 29873, + "logprob": -0.00085639954, + "special": false, + "text": "t" + }, + { + "id": 1170, + "logprob": -0.0054016113, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.13549805, + "special": false, + "text": "\":\"" + }, + { + "id": 19504, + "logprob": -0.8852539, + "special": false, + "text": "David" + }, + { + "id": 3284, + "logprob": -0.16394043, + "special": false, + "text": "\",\"" + }, + { + "id": 29882, + "logprob": -0.08862305, + "special": false, + "text": "h" + }, + { + "id": 711, + "logprob": -0.66259766, + "special": false, + "text": "ob" + }, + { + "id": 1609, + "logprob": -5.51939e-05, + "special": false, + "text": "by" + }, + { + "id": 4710, + "logprob": -0.23120117, + "special": false, + "text": "\":\"" + }, + { + "id": 29911, + "logprob": -2.3730469, + "special": false, + "text": "T" + }, + { + "id": 11003, + "logprob": -0.032104492, + "special": false, + "text": "rees" + }, + { + "id": 3284, + "logprob": -0.22021484, + "special": false, + "text": "\",\"" + }, + { + "id": 4230, + "logprob": -0.06726074, + "special": false, + "text": "last" + }, + { + "id": 1170, + "logprob": -0.003501892, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.0045661926, + "special": false, + "text": "\":\"" + }, + { + "id": 29950, + "logprob": -0.12512207, + "special": false, + "text": "H" + }, + { + "id": 14339, + "logprob": -0.009552002, + "special": false, + "text": "olt" + }, + { + "id": 29920, + "logprob": -0.00042438507, + "special": false, + "text": "z" + }, + { + "id": 3284, + "logprob": -0.11651611, + "special": false, + "text": "\",\"" + }, + { + "id": 29876, + "logprob": -0.29736328, + "special": false, + "text": "n" + }, + { + "id": 398, + "logprob": -0.003030777, + "special": false, + "text": "um" + }, + { + "id": 29907, + "logprob": -0.3774414, + "special": false, + "text": "C" + }, + { + "id": 1446, + "logprob": -0.0003130436, + "special": false, + "text": "ats" + }, + { + "id": 1115, + "logprob": -0.0021514893, + "special": false, + "text": "\":" + }, + { + "id": 29906, + "logprob": -0.071899414, + "special": false, + "text": "2" + }, + { + "id": 29913, + "logprob": -0.018997192, + "special": false, + "text": "}" + }, + { + "id": 2, + "logprob": 0.0, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}" +} diff --git a/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_load.json b/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_load.json new file mode 100644 index 00000000..411f3947 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_load.json @@ -0,0 +1,478 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.0332031, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": 
"dav" + }, + { + "id": 333, + "logprob": -0.04257202, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4785156, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32495117, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7709961, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.33740234, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.00995636, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022735596, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.03125, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04244995, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4863281, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32714844, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7685547, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.33666992, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.01008606, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64160156, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.5, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46557617, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5341797, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022907257, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.0332031, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04257202, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4785156, + "text": "." 
+ }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32495117, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7709961, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.33740234, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.00995636, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022735596, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.0332031, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04257202, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4785156, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32495117, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7709961, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.33740234, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.00995636, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." 
+ }, + { + "id": 510, + "logprob": -0.0022735596, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_regex.json b/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_regex.json new file mode 100644 index 00000000..1ba9ae1e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_regex.json @@ -0,0 +1,109 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 806, + "logprob": -11.890625, + "text": "Wh" + }, + { + "id": 1446, + "logprob": -3.6699219, + "text": "ats" + }, + { + "id": 2921, + "logprob": -7.8203125, + "text": "Go" + }, + { + "id": 468, + "logprob": -8.0703125, + "text": "og" + }, + { + "id": 793, + "logprob": -2.1875, + "text": "les" + }, + { + "id": 16332, + "logprob": -9.7109375, + "text": "DNS" + } + ], + "seed": null, + "tokens": [ + { + "id": 29946, + "logprob": -1.4765625, + "special": false, + "text": "4" + }, + { + "id": 29906, + "logprob": -0.9199219, + "special": false, + "text": "2" + }, + { + "id": 29889, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 29896, + "logprob": -1.1367188, + "special": false, + "text": "1" + }, + { + "id": 29889, + "logprob": -1.4648438, + "special": false, + "text": "." + }, + { + "id": 29896, + "logprob": -0.40722656, + "special": false, + "text": "1" + }, + { + "id": 29889, + "logprob": -0.17419434, + "special": false, + "text": "." + }, + { + "id": 29896, + "logprob": -0.20251465, + "special": false, + "text": "1" + }, + { + "id": 29900, + "logprob": -1.5527344, + "special": false, + "text": "0" + }, + { + "id": 29896, + "logprob": -1.3710938, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": "42.1.1.101" +} diff --git a/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_single_load_instance.json b/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_single_load_instance.json new file mode 100644 index 00000000..7ffb17cb --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_single_load_instance.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7685547, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.33666992, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.009979248, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.53759766, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.0008878708, + "special": false, + "text": "." 
+ }, + { + "id": 510, + "logprob": -0.002275467, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama.json b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama.json new file mode 100644 index 00000000..a7f7d2f0 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -8.6875, + "text": "Test" + }, + { + "id": 2009, + "logprob": -11.546875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 363, + "logprob": -1.5351562, + "special": false, + "text": " for" + }, + { + "id": 847, + "logprob": -2.5722656, + "special": false, + "text": " /" + }, + { + "id": 2754, + "logprob": -2.2714844, + "special": false, + "text": "api" + }, + { + "id": 29914, + "logprob": -0.03414917, + "special": false, + "text": "/" + }, + { + "id": 29894, + "logprob": -0.95996094, + "special": false, + "text": "v" + }, + { + "id": 29896, + "logprob": -0.3635254, + "special": false, + "text": "1" + }, + { + "id": 29914, + "logprob": -0.013031006, + "special": false, + "text": "/" + }, + { + "id": 16418, + "logprob": -3.1523438, + "special": false, + "text": "projects" + }, + { + "id": 29914, + "logprob": -0.43701172, + "special": false, + "text": "/" + }, + { + "id": 29896, + "logprob": -1.9394531, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": " for /api/v1/projects/1" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json new file mode 100644 index 00000000..9f145377 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json @@ -0,0 +1,59 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "stop_sequence", + "generated_tokens": 5, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -8.6875, + "text": "Test" + }, + { + "id": 2009, + "logprob": -11.546875, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 5229, + "logprob": -2.5839844, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.44970703, + "special": false, + "text": ":" + }, + { + "id": 4829, + "logprob": -1.8339844, + "special": false, + "text": " Error" + }, + { + "id": 297, + "logprob": -1.0556641, + "special": false, + "text": " in" + }, + { + "id": 1243, + "logprob": 0.0, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "Test request failed: Error in test" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json new file mode 100644 index 00000000..3543dad2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -8.6875, + "text": "Test" + }, + { + 
"id": 2009, + "logprob": -11.546875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 363, + "logprob": -1.5351562, + "special": false, + "text": " for" + }, + { + "id": 847, + "logprob": -2.5566406, + "special": false, + "text": " /" + }, + { + "id": 2754, + "logprob": -2.2519531, + "special": false, + "text": "api" + }, + { + "id": 29914, + "logprob": -0.03414917, + "special": false, + "text": "/" + }, + { + "id": 29894, + "logprob": -0.96240234, + "special": false, + "text": "v" + }, + { + "id": 29896, + "logprob": -0.3647461, + "special": false, + "text": "1" + }, + { + "id": 29914, + "logprob": -0.012901306, + "special": false, + "text": "/" + }, + { + "id": 16418, + "logprob": -3.1542969, + "special": false, + "text": "projects" + }, + { + "id": 29914, + "logprob": -0.4362793, + "special": false, + "text": "/" + }, + { + "id": 29896, + "logprob": -1.9394531, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": " for /api/v1/projects/1" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -8.6875, + "text": "Test" + }, + { + "id": 2009, + "logprob": -11.546875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 363, + "logprob": -1.5332031, + "special": false, + "text": " for" + }, + { + "id": 847, + "logprob": -2.5625, + "special": false, + "text": " /" + }, + { + "id": 2754, + "logprob": -2.2617188, + "special": false, + "text": "api" + }, + { + "id": 29914, + "logprob": -0.033996582, + "special": false, + "text": "/" + }, + { + "id": 29894, + "logprob": -0.9609375, + "special": false, + "text": "v" + }, + { + "id": 29896, + "logprob": -0.36572266, + "special": false, + "text": "1" + }, + { + "id": 29914, + "logprob": -0.0129776, + "special": false, + "text": "/" + }, + { + "id": 16418, + "logprob": -3.15625, + "special": false, + "text": "projects" + }, + { + "id": 29914, + "logprob": -0.4362793, + "special": false, + "text": "/" + }, + { + "id": 29896, + "logprob": -1.9394531, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": " for /api/v1/projects/1" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -8.6875, + "text": "Test" + }, + { + "id": 2009, + "logprob": -11.546875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 363, + "logprob": -1.5332031, + "special": false, + "text": " for" + }, + { + "id": 847, + "logprob": -2.5625, + "special": false, + "text": " /" + }, + { + "id": 2754, + "logprob": -2.2617188, + "special": false, + "text": "api" + }, + { + "id": 29914, + "logprob": -0.033996582, + "special": false, + "text": "/" + }, + { + "id": 29894, + "logprob": -0.9609375, + "special": false, + "text": "v" + }, + { + "id": 29896, + "logprob": -0.36572266, + "special": false, + "text": "1" + }, + { + "id": 29914, + "logprob": -0.0129776, + "special": false, + "text": "/" + }, + { + "id": 16418, + "logprob": -3.15625, + "special": false, + "text": "projects" + }, + { + "id": 29914, + "logprob": -0.4362793, + "special": false, + "text": "/" + }, + { + "id": 29896, + "logprob": -1.9394531, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": " for /api/v1/projects/1" + }, + { + "details": { + 
"best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -8.6875, + "text": "Test" + }, + { + "id": 2009, + "logprob": -11.546875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 363, + "logprob": -1.5332031, + "special": false, + "text": " for" + }, + { + "id": 847, + "logprob": -2.5625, + "special": false, + "text": " /" + }, + { + "id": 2754, + "logprob": -2.2617188, + "special": false, + "text": "api" + }, + { + "id": 29914, + "logprob": -0.033996582, + "special": false, + "text": "/" + }, + { + "id": 29894, + "logprob": -0.9609375, + "special": false, + "text": "v" + }, + { + "id": 29896, + "logprob": -0.36572266, + "special": false, + "text": "1" + }, + { + "id": 29914, + "logprob": -0.0129776, + "special": false, + "text": "/" + }, + { + "id": 16418, + "logprob": -3.15625, + "special": false, + "text": "projects" + }, + { + "id": 29914, + "logprob": -0.4362793, + "special": false, + "text": "/" + }, + { + "id": 29896, + "logprob": -1.9394531, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": " for /api/v1/projects/1" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2.json b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2.json new file mode 100644 index 00000000..f6e4bb90 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.4375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9316406, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.5136719, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.7783203, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2314453, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -2.0019531, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.5009766, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.057434082, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.4912109, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2636719, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.4042969, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_all_params.json new file mode 100644 index 00000000..6b38e709 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_all_params.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.453125, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 13, + "logprob": -1.9980469, + "special": false, + "text": "." 
+ }, + { + "id": 578, + "logprob": -0.15795898, + "special": false, + "text": " The" + }, + { + "id": 3622, + "logprob": -1.0458984, + "special": false, + "text": " server" + }, + { + "id": 31680, + "logprob": -1.3623047, + "special": false, + "text": " responds" + }, + { + "id": 449, + "logprob": 0.0, + "special": false, + "text": " with" + }, + { + "id": 264, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 330, + "logprob": -0.5678711, + "special": false, + "text": " \"" + }, + { + "id": 1049, + "logprob": -0.12322998, + "special": false, + "text": "200" + }, + { + "id": 10619, + "logprob": 0.0, + "special": false, + "text": " OK" + }, + { + "id": 1, + "logprob": 0.0, + "special": false, + "text": "\"" + } + ], + "top_tokens": null + }, + "generated_text": "Test request. The server responds with a \"200 OK\"" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_load.json b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_load.json new file mode 100644 index 00000000..ed369a87 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_load.json @@ -0,0 +1,338 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.453125, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9785156, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.4941406, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.79345703, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2324219, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -1.9794922, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.4892578, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.058258057, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.4892578, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2783203, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.3945312, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.40625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9433594, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.4726562, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.8022461, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2509766, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -1.984375, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.4677734, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.059173584, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.4990234, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2822266, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.3867188, + "special": false, + "text": " }\n\n" + } + ], + 
"top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.421875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9511719, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.46875, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.77490234, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2558594, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -1.984375, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.4990234, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.059143066, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.4941406, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2578125, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.3964844, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2323, + "logprob": null, + "text": "Test" + }, + { + "id": 1715, + "logprob": -11.4140625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.9101562, + "special": false, + "text": ":" + }, + { + "id": 330, + "logprob": -3.5039062, + "special": false, + "text": " \"" + }, + { + "id": 489, + "logprob": -0.8076172, + "special": false, + "text": " +" + }, + { + "id": 1715, + "logprob": -1.2236328, + "special": false, + "text": " request" + }, + { + "id": 489, + "logprob": -1.9853516, + "special": false, + "text": " +" + }, + { + "id": 2990, + "logprob": -1.4892578, + "special": false, + "text": " \"\\" + }, + { + "id": 77, + "logprob": -0.056671143, + "special": false, + "text": "n" + }, + { + "id": 702, + "logprob": -1.5107422, + "special": false, + "text": "\"\n" + }, + { + "id": 262, + "logprob": -1.2597656, + "special": false, + "text": " " + }, + { + "id": 557, + "logprob": -2.4042969, + "special": false, + "text": " }\n\n" + } + ], + "top_tokens": null + }, + "generated_text": ": \" + request + \"\\n\"\n }\n\n" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json new file mode 100644 index 00000000..7797cc6c --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.7890625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3359375, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8779297, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2744141, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.6933594, + "special": false, + "text": "\n" + }, + { + "id": 3057, + 
"logprob": -1.4648438, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15600586, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.8027344, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.23022461, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.0069885254, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.02218628, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json new file mode 100644 index 00000000..fa2fd4a2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.84375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.6015625, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 29899, + "logprob": -1.5625, + "special": false, + "text": "-" + }, + { + "id": 1454, + "logprob": -0.20410156, + "special": false, + "text": "for" + }, + { + "id": 29899, + "logprob": 0.0, + "special": false, + "text": "-" + }, + { + "id": 9342, + "logprob": 0.0, + "special": false, + "text": "comment" + }, + { + "id": 29901, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 396, + "logprob": -0.27685547, + "special": false, + "text": " #" + }, + { + "id": 29906, + "logprob": -0.4970703, + "special": false, + "text": "2" + }, + { + "id": 29900, + "logprob": -0.80615234, + "special": false, + "text": "0" + }, + { + "id": 29896, + "logprob": 0.0, + "special": false, + "text": "1" + }, + { + "id": 29955, + "logprob": -1.0751953, + "special": false, + "text": "7" + } + ], + "top_tokens": null + }, + "generated_text": "Test request-for-comment: #2017" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json new file mode 100644 index 00000000..594b7351 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.828125, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.609375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3300781, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8740234, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2646484, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.7158203, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4667969, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15344238, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.81591797, + "special": false, + "text": "\n" + }, + { + "id": 3057, + 
"logprob": -0.22973633, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007045746, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021957397, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.84375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.59375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3378906, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8779297, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2636719, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.6992188, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4589844, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15344238, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.79052734, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22937012, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007041931, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.022140503, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.84375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.609375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3261719, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8730469, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2587891, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.6894531, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.46875, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.1541748, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.80322266, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22912598, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.0070495605, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021606445, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.84375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.6015625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3320312, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.875, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2646484, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.6884766, + 
"special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4589844, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15185547, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.79833984, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22827148, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.006996155, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021560669, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json new file mode 100644 index 00000000..d8a298eb --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json @@ -0,0 +1,98 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 338, + "logprob": -10.0078125, + "text": "is" + }, + { + "id": 21784, + "logprob": -15.515625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -2.8847656, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -4.140625, + "text": "?" + } + ], + "seed": 0, + "tokens": [ + { + "id": 13, + "logprob": -1.1582031, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.23083496, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": 0.0, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": 0.0, + "special": false, + "text": " learning" + }, + { + "id": 29892, + "logprob": -0.61816406, + "special": false, + "text": "," + }, + { + "id": 607, + "logprob": -0.7089844, + "special": false, + "text": " which" + }, + { + "id": 508, + "logprob": -1.7724609, + "special": false, + "text": " can" + }, + { + "id": 367, + "logprob": 0.0, + "special": false, + "text": " be" + }, + { + "id": 5545, + "logprob": 0.0, + "special": false, + "text": " considered" + }, + { + "id": 408, + "logprob": -0.3869629, + "special": false, + "text": " as" + } + ] + }, + "generated_text": "What is Deep Learning?\nDeep learning, which can be considered as" +} diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json new file mode 100644 index 00000000..413af1d7 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json @@ -0,0 +1,414 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -10.734375, + "text": "What" + }, + { + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2753906, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.48046875, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.1845703, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.5727539, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.00010967255, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.04510498, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.00020992756, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.0046539307, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025844574, + "special": false, + "text": " learning" + } + ] + }, + "generated_text": "\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -10.734375, + "text": "What" + }, + { + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2724609, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.47729492, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.1826172, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.56689453, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.000108003616, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.044433594, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.0002104044, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.004711151, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025892258, + "special": false, + "text": " learning" + } + ] + }, + "generated_text": "\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -10.734375, + "text": "What" + }, + { + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2724609, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.47729492, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.1826172, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.56689453, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.000108003616, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.044433594, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.0002104044, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.004711151, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025892258, + "special": false, + "text": " learning" + } + ] + }, + "generated_text": "\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -10.734375, + "text": "What" + }, + { + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2724609, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.47729492, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.1826172, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.56689453, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.000108003616, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.044433594, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.0002104044, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.004711151, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025892258, + "special": false, + "text": " learning" + } + ] + }, + "generated_text": "\nDeep learning is a subset of machine learning" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json new file mode 100644 index 00000000..15754b14 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json @@ -0,0 +1,103 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -10.734375, + "text": "What" + }, + { + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2753906, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.48046875, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.1845703, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.5727539, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.000108122826, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.044433594, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.01852417, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.0002104044, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.004787445, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00026226044, + "special": false, + "text": " learning" + } + ] + }, + "generated_text": "\nDeep learning is a subset of machine learning" +} diff --git a/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral.json b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral.json new file mode 100644 index 00000000..4e7de9a6 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -12.9140625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.7578125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 28747, + "logprob": -0.54785156, + "special": false, + "text": ":" + }, + { + "id": 3169, + "logprob": -1.4091797, + "special": false, + "text": " Let" + }, + { + "id": 307, + "logprob": -3.0273438, + "special": false, + "text": " n" + }, + { + "id": 327, + "logprob": -0.94433594, + "special": false, + "text": " =" + }, + { + "id": 28705, + "logprob": -0.81347656, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.2958984, + "special": false, + "text": "1" + }, + { + "id": 28734, + "logprob": -2.0644531, + "special": false, + "text": "0" + }, + { + "id": 387, + "logprob": -1.9580078, + "special": false, + "text": " -" + }, + { + "id": 28705, + "logprob": -0.5073242, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.1816406, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": ": Let n = 10 - 1" +} diff --git a/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_all_params.json b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_all_params.json new file mode 100644 index 00000000..c0dc6471 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -12.9140625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.7578125, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 28747, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 3169, + "logprob": -0.1307373, + "special": false, + "text": " Let" + }, + { + "id": 332, + "logprob": -2.3359375, + "special": false, + "text": " u" + }, + { + "id": 347, + 
"logprob": 0.0, + "special": false, + "text": " be" + }, + { + "id": 325, + "logprob": -1.0234375, + "special": false, + "text": " (" + }, + { + "id": 28734, + "logprob": -2.0292969, + "special": false, + "text": "0" + }, + { + "id": 648, + "logprob": -1.0439453, + "special": false, + "text": " +" + }, + { + "id": 28705, + "logprob": -0.24499512, + "special": false, + "text": " " + }, + { + "id": 28770, + "logprob": -0.5073242, + "special": false, + "text": "3" + }, + { + "id": 387, + "logprob": -1.5507812, + "special": false, + "text": " -" + } + ], + "top_tokens": null + }, + "generated_text": "Test request: Let u be (0 + 3 -" +} diff --git a/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_load.json b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_load.json new file mode 100644 index 00000000..9d133077 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -12.9140625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.7578125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 28747, + "logprob": -0.55078125, + "special": false, + "text": ":" + }, + { + "id": 3169, + "logprob": -1.4140625, + "special": false, + "text": " Let" + }, + { + "id": 307, + "logprob": -3.0273438, + "special": false, + "text": " n" + }, + { + "id": 327, + "logprob": -0.94140625, + "special": false, + "text": " =" + }, + { + "id": 28705, + "logprob": -0.8173828, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.2978516, + "special": false, + "text": "1" + }, + { + "id": 28734, + "logprob": -2.0664062, + "special": false, + "text": "0" + }, + { + "id": 387, + "logprob": -1.9560547, + "special": false, + "text": " -" + }, + { + "id": 28705, + "logprob": -0.5078125, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.1787109, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": ": Let n = 10 - 1" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -12.9140625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.7578125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 28747, + "logprob": -0.54785156, + "special": false, + "text": ":" + }, + { + "id": 3169, + "logprob": -1.4111328, + "special": false, + "text": " Let" + }, + { + "id": 307, + "logprob": -3.0292969, + "special": false, + "text": " n" + }, + { + "id": 327, + "logprob": -0.94433594, + "special": false, + "text": " =" + }, + { + "id": 28705, + "logprob": -0.8178711, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.2939453, + "special": false, + "text": "1" + }, + { + "id": 28734, + "logprob": -2.0644531, + "special": false, + "text": "0" + }, + { + "id": 387, + "logprob": -1.9550781, + "special": false, + "text": " -" + }, + { + "id": 28705, + "logprob": -0.5078125, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.1796875, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": ": Let n = 10 - 1" + }, + { + "details": { + "best_of_sequences": null, + 
"finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -12.9140625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.7578125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 28747, + "logprob": -0.55078125, + "special": false, + "text": ":" + }, + { + "id": 3169, + "logprob": -1.4140625, + "special": false, + "text": " Let" + }, + { + "id": 307, + "logprob": -3.0273438, + "special": false, + "text": " n" + }, + { + "id": 327, + "logprob": -0.94140625, + "special": false, + "text": " =" + }, + { + "id": 28705, + "logprob": -0.8173828, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.2978516, + "special": false, + "text": "1" + }, + { + "id": 28734, + "logprob": -2.0664062, + "special": false, + "text": "0" + }, + { + "id": 387, + "logprob": -1.9560547, + "special": false, + "text": " -" + }, + { + "id": 28705, + "logprob": -0.5078125, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.1787109, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": ": Let n = 10 - 1" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -12.9140625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.7578125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 28747, + "logprob": -0.55078125, + "special": false, + "text": ":" + }, + { + "id": 3169, + "logprob": -1.4140625, + "special": false, + "text": " Let" + }, + { + "id": 307, + "logprob": -3.0273438, + "special": false, + "text": " n" + }, + { + "id": 327, + "logprob": -0.94140625, + "special": false, + "text": " =" + }, + { + "id": 28705, + "logprob": -0.8173828, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.2978516, + "special": false, + "text": "1" + }, + { + "id": 28734, + "logprob": -2.0664062, + "special": false, + "text": "0" + }, + { + "id": 387, + "logprob": -1.9560547, + "special": false, + "text": " -" + }, + { + "id": 28705, + "logprob": -0.5078125, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.1787109, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": ": Let n = 10 - 1" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_neox/test_flash_neox.json b/integration-tests/models/__snapshots__/test_flash_neox/test_flash_neox.json new file mode 100644 index 00000000..66ddbaef --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_neox/test_flash_neox.json @@ -0,0 +1,113 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.234375, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.1054688, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.953125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -4.0820312, + "text": " today" + }, + { + "id": 32, + "logprob": -0.15148926, + "text": "?" 
+ }, + { + "id": 50279, + "logprob": -0.27026367, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.88378906, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.94921875, + "special": false, + "text": "'m" + }, + { + "id": 417, + "logprob": -2.2402344, + "special": false, + "text": " not" + }, + { + "id": 2119, + "logprob": -0.3725586, + "special": false, + "text": " sure" + }, + { + "id": 13, + "logprob": -1.078125, + "special": false, + "text": "," + }, + { + "id": 534, + "logprob": -0.67822266, + "special": false, + "text": " which" + }, + { + "id": 310, + "logprob": -1.3837891, + "special": false, + "text": " is" + }, + { + "id": 253, + "logprob": -1.7050781, + "special": false, + "text": " the" + }, + { + "id": 1682, + "logprob": -0.052001953, + "special": false, + "text": " best" + }, + { + "id": 1039, + "logprob": -2.0390625, + "special": false, + "text": " way" + } + ] + }, + "generated_text": "I'm not sure, which is the best way" +} diff --git a/integration-tests/models/__snapshots__/test_flash_neox/test_flash_neox_load.json b/integration-tests/models/__snapshots__/test_flash_neox/test_flash_neox_load.json new file mode 100644 index 00000000..5ef6b3a2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_neox/test_flash_neox_load.json @@ -0,0 +1,454 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.234375, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.21875, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.9375, + "text": " mood" + }, + { + "id": 3063, + "logprob": -4.1015625, + "text": " today" + }, + { + "id": 32, + "logprob": -0.15319824, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.2614746, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8886719, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.98046875, + "special": false, + "text": "'m" + }, + { + "id": 417, + "logprob": -2.2265625, + "special": false, + "text": " not" + }, + { + "id": 2119, + "logprob": -0.3479004, + "special": false, + "text": " sure" + }, + { + "id": 13, + "logprob": -1.0117188, + "special": false, + "text": "," + }, + { + "id": 534, + "logprob": -0.67871094, + "special": false, + "text": " which" + }, + { + "id": 310, + "logprob": -1.421875, + "special": false, + "text": " is" + }, + { + "id": 253, + "logprob": -1.7382812, + "special": false, + "text": " the" + }, + { + "id": 1682, + "logprob": -0.051330566, + "special": false, + "text": " best" + }, + { + "id": 1039, + "logprob": -2.0390625, + "special": false, + "text": " way" + } + ] + }, + "generated_text": "I'm not sure, which is the best way" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.234375, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.1054688, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.953125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -4.0820312, + "text": " today" + }, + { + "id": 32, + "logprob": -0.15148926, + "text": "?" 
+ }, + { + "id": 50279, + "logprob": -0.27026367, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.88378906, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9819336, + "special": false, + "text": "'m" + }, + { + "id": 417, + "logprob": -2.2421875, + "special": false, + "text": " not" + }, + { + "id": 2119, + "logprob": -0.3474121, + "special": false, + "text": " sure" + }, + { + "id": 13, + "logprob": -1.078125, + "special": false, + "text": "," + }, + { + "id": 534, + "logprob": -0.69140625, + "special": false, + "text": " which" + }, + { + "id": 310, + "logprob": -1.4072266, + "special": false, + "text": " is" + }, + { + "id": 253, + "logprob": -1.7041016, + "special": false, + "text": " the" + }, + { + "id": 1682, + "logprob": -0.053375244, + "special": false, + "text": " best" + }, + { + "id": 1039, + "logprob": -2.0351562, + "special": false, + "text": " way" + } + ] + }, + "generated_text": "I'm not sure, which is the best way" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.234375, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.21875, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.9375, + "text": " mood" + }, + { + "id": 3063, + "logprob": -4.1015625, + "text": " today" + }, + { + "id": 32, + "logprob": -0.15319824, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.2614746, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8886719, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.98046875, + "special": false, + "text": "'m" + }, + { + "id": 417, + "logprob": -2.2265625, + "special": false, + "text": " not" + }, + { + "id": 2119, + "logprob": -0.3479004, + "special": false, + "text": " sure" + }, + { + "id": 13, + "logprob": -1.0117188, + "special": false, + "text": "," + }, + { + "id": 534, + "logprob": -0.67871094, + "special": false, + "text": " which" + }, + { + "id": 310, + "logprob": -1.421875, + "special": false, + "text": " is" + }, + { + "id": 253, + "logprob": -1.7382812, + "special": false, + "text": " the" + }, + { + "id": 1682, + "logprob": -0.051330566, + "special": false, + "text": " best" + }, + { + "id": 1039, + "logprob": -2.0390625, + "special": false, + "text": " way" + } + ] + }, + "generated_text": "I'm not sure, which is the best way" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.234375, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.21875, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.9375, + "text": " mood" + }, + { + "id": 3063, + "logprob": -4.1015625, + "text": " today" + }, + { + "id": 32, + "logprob": -0.15319824, + "text": "?" 
+ }, + { + "id": 50279, + "logprob": -0.2614746, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8886719, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.98046875, + "special": false, + "text": "'m" + }, + { + "id": 417, + "logprob": -2.2265625, + "special": false, + "text": " not" + }, + { + "id": 2119, + "logprob": -0.3479004, + "special": false, + "text": " sure" + }, + { + "id": 13, + "logprob": -1.0117188, + "special": false, + "text": "," + }, + { + "id": 534, + "logprob": -0.67871094, + "special": false, + "text": " which" + }, + { + "id": 310, + "logprob": -1.421875, + "special": false, + "text": " is" + }, + { + "id": 253, + "logprob": -1.7382812, + "special": false, + "text": " the" + }, + { + "id": 1682, + "logprob": -0.051330566, + "special": false, + "text": " best" + }, + { + "id": 1039, + "logprob": -2.0390625, + "special": false, + "text": " way" + } + ] + }, + "generated_text": "I'm not sure, which is the best way" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_neox_sharded/test_flash_neox.json b/integration-tests/models/__snapshots__/test_flash_neox_sharded/test_flash_neox.json new file mode 100644 index 00000000..787704ce --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_neox_sharded/test_flash_neox.json @@ -0,0 +1,163 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.03125, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1601562, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.4609375, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005657196, + "text": "e" + }, + { + "id": 13, + "logprob": -7.28125, + "text": "," + }, + { + "id": 285, + "logprob": -0.2980957, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1679688, + "text": " what" + }, + { + "id": 434, + "logprob": -5.6210938, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.81103516, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6640625, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.265625, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.5078125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1582031, + "text": " word" + }, + { + "id": 32, + "logprob": -0.008720398, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.4726562, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.265625, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.63183594, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.5390625, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.045684814, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.002090454, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.3589859e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.0009455681, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.088012695, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12585449, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.017196655, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.49731445, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" +} diff --git a/integration-tests/models/__snapshots__/test_flash_neox_sharded/test_flash_neox_load.json b/integration-tests/models/__snapshots__/test_flash_neox_sharded/test_flash_neox_load.json new file mode 100644 index 00000000..47d6a77e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_neox_sharded/test_flash_neox_load.json @@ -0,0 +1,654 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.03125, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1601562, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.4609375, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005657196, + "text": "e" + }, + { + "id": 13, + "logprob": -7.28125, + "text": "," + }, + { + "id": 285, + "logprob": -0.2980957, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1679688, + "text": " what" + }, + { + "id": 434, + "logprob": -5.6210938, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.81103516, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6640625, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.265625, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.5078125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1582031, + "text": " word" + }, + { + "id": 32, + "logprob": -0.008720398, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.4726562, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.265625, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.63183594, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.5488281, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.045684814, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.00207901, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.335144e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00097227097, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.0892334, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12463379, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.01737976, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.50341797, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.03125, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1601562, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.4609375, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005657196, + "text": "e" + }, + { + "id": 13, + "logprob": -7.28125, + "text": "," + }, + { + "id": 285, + "logprob": -0.2980957, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1679688, + "text": " what" + }, + { + "id": 434, + "logprob": -5.6210938, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.81103516, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6640625, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.265625, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.5078125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1582031, + "text": " word" + }, + { + "id": 32, + "logprob": -0.008720398, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.4726562, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.265625, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.63183594, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.5488281, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.045684814, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.00207901, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.335144e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00097227097, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.0892334, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12463379, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.01737976, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.50341797, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.03125, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1601562, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.4609375, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005657196, + "text": "e" + }, + { + "id": 13, + "logprob": -7.28125, + "text": "," + }, + { + "id": 285, + "logprob": -0.2980957, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1679688, + "text": " what" + }, + { + "id": 434, + "logprob": -5.6210938, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.81103516, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6640625, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.265625, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.5078125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1582031, + "text": " word" + }, + { + "id": 32, + "logprob": -0.008720398, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.4726562, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.265625, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.63183594, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.5488281, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.045684814, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.00207901, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.335144e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00097227097, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.0892334, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12463379, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.01737976, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.50341797, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.03125, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1601562, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.4609375, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005657196, + "text": "e" + }, + { + "id": 13, + "logprob": -7.28125, + "text": "," + }, + { + "id": 285, + "logprob": -0.2980957, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1679688, + "text": " what" + }, + { + "id": 434, + "logprob": -5.6210938, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.81103516, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6640625, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.265625, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.5078125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1582031, + "text": " word" + }, + { + "id": 32, + "logprob": -0.008720398, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.4726562, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.265625, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.63183594, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.5488281, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.045684814, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.00207901, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.335144e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00097227097, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.0892334, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12463379, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.01737976, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.50341797, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma.json b/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma.json new file mode 100644 index 00000000..037e0b16 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma.json @@ -0,0 +1,25 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 2, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 54901, + "logprob": -0.72753906, + "special": false, + "text": "beach" + }, + { + "id": 1, + "logprob": -0.011009216, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": "beach" +} diff --git a/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi.json b/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi.json new file mode 100644 index 00000000..51d969b2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 14402, + "logprob": null, + "text": "Test" + }, + { + "id": 2581, + "logprob": -11.6171875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.3203125, + "special": false, + "text": ":" + }, + { + "id": 1391, + "logprob": -0.98779297, + "special": false, + "text": " {" + }, + { + "id": 25927, + "logprob": -0.76660156, + "special": false, + "text": "request" + }, + { + "id": 92, + "logprob": -0.7246094, + "special": false, + "text": "}" + }, + { + "id": 4943, + "logprob": -0.41333008, + "special": false, + "text": "\")" + }, + { + "id": 198, + "logprob": -0.11785889, + "special": false, + "text": "\n" + }, + { + "id": 50280, + "logprob": -0.97265625, + "special": false, + "text": " " + }, + { + "id": 26209, + "logprob": -1.4414062, + "special": false, + "text": "response" + }, + { + "id": 796, + "logprob": -0.0569458, + "special": false, + "text": " =" + }, + { + "id": 2116, + "logprob": -1.1533203, + "special": false, + "text": " self" + } + ], + "top_tokens": null + }, + "generated_text": ": {request}\")\n response = self" +} diff --git a/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_all_params.json 
b/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_all_params.json new file mode 100644 index 00000000..221ff13d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_all_params.json @@ -0,0 +1,60 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "stop_sequence", + "generated_tokens": 6, + "prefill": [ + { + "id": 14402, + "logprob": null, + "text": "Test" + }, + { + "id": 2581, + "logprob": -11.6171875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 284, + "logprob": -0.19421387, + "special": false, + "text": " to" + }, + { + "id": 3758, + "logprob": -0.62597656, + "special": false, + "text": " send" + }, + { + "id": 1366, + "logprob": -0.87060547, + "special": false, + "text": " data" + }, + { + "id": 625, + "logprob": -0.88427734, + "special": false, + "text": " over" + }, + { + "id": 257, + "logprob": -1.0830078, + "special": false, + "text": " a" + }, + { + "id": 3127, + "logprob": -1.9462891, + "special": false, + "text": " network" + } + ], + "top_tokens": null + }, + "generated_text": "Test request to send data over a network" +} diff --git a/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_load.json b/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_load.json new file mode 100644 index 00000000..62f7fd32 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_load.json @@ -0,0 +1,338 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 14402, + "logprob": null, + "text": "Test" + }, + { + "id": 2581, + "logprob": -11.6171875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.3203125, + "special": false, + "text": ":" + }, + { + "id": 1391, + "logprob": -0.98779297, + "special": false, + "text": " {" + }, + { + "id": 25927, + "logprob": -0.7729492, + "special": false, + "text": "request" + }, + { + "id": 92, + "logprob": -0.7241211, + "special": false, + "text": "}" + }, + { + "id": 4943, + "logprob": -0.4091797, + "special": false, + "text": "\")" + }, + { + "id": 198, + "logprob": -0.119018555, + "special": false, + "text": "\n" + }, + { + "id": 50280, + "logprob": -0.9707031, + "special": false, + "text": " " + }, + { + "id": 26209, + "logprob": -1.4414062, + "special": false, + "text": "response" + }, + { + "id": 796, + "logprob": -0.056854248, + "special": false, + "text": " =" + }, + { + "id": 2116, + "logprob": -1.1533203, + "special": false, + "text": " self" + } + ], + "top_tokens": null + }, + "generated_text": ": {request}\")\n response = self" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 14402, + "logprob": null, + "text": "Test" + }, + { + "id": 2581, + "logprob": -11.6171875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.3203125, + "special": false, + "text": ":" + }, + { + "id": 1391, + "logprob": -0.98779297, + "special": false, + "text": " {" + }, + { + "id": 25927, + "logprob": -0.7729492, + "special": false, + "text": "request" + }, + { + "id": 92, + "logprob": -0.7241211, + "special": false, + "text": "}" + }, + { + "id": 4943, + "logprob": -0.4091797, + "special": false, + "text": "\")" + }, + { + "id": 198, + "logprob": -0.119018555, + "special": false, + "text": "\n" + }, + { + "id": 50280, + "logprob": -0.9707031, + "special": false, + 
"text": " " + }, + { + "id": 26209, + "logprob": -1.4414062, + "special": false, + "text": "response" + }, + { + "id": 796, + "logprob": -0.056854248, + "special": false, + "text": " =" + }, + { + "id": 2116, + "logprob": -1.1533203, + "special": false, + "text": " self" + } + ], + "top_tokens": null + }, + "generated_text": ": {request}\")\n response = self" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 14402, + "logprob": null, + "text": "Test" + }, + { + "id": 2581, + "logprob": -11.6171875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.3203125, + "special": false, + "text": ":" + }, + { + "id": 1391, + "logprob": -0.98779297, + "special": false, + "text": " {" + }, + { + "id": 25927, + "logprob": -0.7729492, + "special": false, + "text": "request" + }, + { + "id": 92, + "logprob": -0.7241211, + "special": false, + "text": "}" + }, + { + "id": 4943, + "logprob": -0.4091797, + "special": false, + "text": "\")" + }, + { + "id": 198, + "logprob": -0.119018555, + "special": false, + "text": "\n" + }, + { + "id": 50280, + "logprob": -0.9707031, + "special": false, + "text": " " + }, + { + "id": 26209, + "logprob": -1.4414062, + "special": false, + "text": "response" + }, + { + "id": 796, + "logprob": -0.056854248, + "special": false, + "text": " =" + }, + { + "id": 2116, + "logprob": -1.1533203, + "special": false, + "text": " self" + } + ], + "top_tokens": null + }, + "generated_text": ": {request}\")\n response = self" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 14402, + "logprob": null, + "text": "Test" + }, + { + "id": 2581, + "logprob": -11.6171875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.3203125, + "special": false, + "text": ":" + }, + { + "id": 1391, + "logprob": -0.98779297, + "special": false, + "text": " {" + }, + { + "id": 25927, + "logprob": -0.7729492, + "special": false, + "text": "request" + }, + { + "id": 92, + "logprob": -0.7241211, + "special": false, + "text": "}" + }, + { + "id": 4943, + "logprob": -0.4091797, + "special": false, + "text": "\")" + }, + { + "id": 198, + "logprob": -0.119018555, + "special": false, + "text": "\n" + }, + { + "id": 50280, + "logprob": -0.9707031, + "special": false, + "text": " " + }, + { + "id": 26209, + "logprob": -1.4414062, + "special": false, + "text": "response" + }, + { + "id": 796, + "logprob": -0.056854248, + "special": false, + "text": " =" + }, + { + "id": 2116, + "logprob": -1.1533203, + "special": false, + "text": " self" + } + ], + "top_tokens": null + }, + "generated_text": ": {request}\")\n response = self" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2.json b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2.json new file mode 100644 index 00000000..7219f9e6 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2271, + "logprob": null, + "text": "Test" + }, + { + "id": 1681, + "logprob": -8.8515625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" + }, + { + "id": 2, + "logprob": -2.9160156, + "special": false, 
+ "text": "#" + }, + { + "id": 4230, + "logprob": -3.1035156, + "special": false, + "text": " Create" + }, + { + "id": 264, + "logprob": -1.1025391, + "special": false, + "text": " a" + }, + { + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" + }, + { + "id": 198, + "logprob": -1.1953125, + "special": false, + "text": "\n" + }, + { + "id": 2035, + "logprob": -1.3203125, + "special": false, + "text": "request" + }, + { + "id": 284, + "logprob": -0.13537598, + "special": false, + "text": " =" + }, + { + "id": 7388, + "logprob": -1.2402344, + "special": false, + "text": " requests" + }, + { + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" + } + ], + "top_tokens": null + }, + "generated_text": "\n# Create a request\nrequest = requests.get" +} diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_all_params.json b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_all_params.json new file mode 100644 index 00000000..4a2936af --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_all_params.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2271, + "logprob": null, + "text": "Test" + }, + { + "id": 1681, + "logprob": -8.8515625, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 311, + "logprob": -1.4277344, + "special": false, + "text": " to" + }, + { + "id": 279, + "logprob": -0.65478516, + "special": false, + "text": " the" + }, + { + "id": 2473, + "logprob": -1.8300781, + "special": false, + "text": " service" + }, + { + "id": 382, + "logprob": -0.75, + "special": false, + "text": ".\n\n" + }, + { + "id": 286, + "logprob": -0.11621094, + "special": false, + "text": " " + }, + { + "id": 549, + "logprob": 0.0, + "special": false, + "text": " :" + }, + { + "id": 689, + "logprob": -0.48608398, + "special": false, + "text": "return" + }, + { + "id": 25, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 5949, + "logprob": -0.5756836, + "special": false, + "text": " Response" + }, + { + "id": 504, + "logprob": -0.24499512, + "special": false, + "text": " from" + } + ], + "top_tokens": null + }, + "generated_text": "Test request to the service.\n\n :return: Response from" +} diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_load.json b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_load.json new file mode 100644 index 00000000..4786ff24 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_load.json @@ -0,0 +1,338 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2271, + "logprob": null, + "text": "Test" + }, + { + "id": 1681, + "logprob": -8.8515625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" + }, + { + "id": 2, + "logprob": -2.9140625, + "special": false, + "text": "#" + }, + { + "id": 4230, + "logprob": -3.1054688, + "special": false, + "text": " Create" + }, + { + "id": 264, + "logprob": -1.0966797, + "special": false, + "text": " a" + }, + { + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" + }, + { + "id": 198, + "logprob": -1.1923828, + "special": false, + "text": "\n" + }, + { + "id": 2035, + "logprob": 
-1.3193359, + "special": false, + "text": "request" + }, + { + "id": 284, + "logprob": -0.13586426, + "special": false, + "text": " =" + }, + { + "id": 7388, + "logprob": -1.2412109, + "special": false, + "text": " requests" + }, + { + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" + } + ], + "top_tokens": null + }, + "generated_text": "\n# Create a request\nrequest = requests.get" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2271, + "logprob": null, + "text": "Test" + }, + { + "id": 1681, + "logprob": -8.8515625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" + }, + { + "id": 2, + "logprob": -2.9140625, + "special": false, + "text": "#" + }, + { + "id": 4230, + "logprob": -3.1054688, + "special": false, + "text": " Create" + }, + { + "id": 264, + "logprob": -1.0966797, + "special": false, + "text": " a" + }, + { + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" + }, + { + "id": 198, + "logprob": -1.1923828, + "special": false, + "text": "\n" + }, + { + "id": 2035, + "logprob": -1.3193359, + "special": false, + "text": "request" + }, + { + "id": 284, + "logprob": -0.13586426, + "special": false, + "text": " =" + }, + { + "id": 7388, + "logprob": -1.2412109, + "special": false, + "text": " requests" + }, + { + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" + } + ], + "top_tokens": null + }, + "generated_text": "\n# Create a request\nrequest = requests.get" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2271, + "logprob": null, + "text": "Test" + }, + { + "id": 1681, + "logprob": -8.8515625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" + }, + { + "id": 2, + "logprob": -2.9140625, + "special": false, + "text": "#" + }, + { + "id": 4230, + "logprob": -3.1054688, + "special": false, + "text": " Create" + }, + { + "id": 264, + "logprob": -1.0966797, + "special": false, + "text": " a" + }, + { + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" + }, + { + "id": 198, + "logprob": -1.1923828, + "special": false, + "text": "\n" + }, + { + "id": 2035, + "logprob": -1.3193359, + "special": false, + "text": "request" + }, + { + "id": 284, + "logprob": -0.13586426, + "special": false, + "text": " =" + }, + { + "id": 7388, + "logprob": -1.2412109, + "special": false, + "text": " requests" + }, + { + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" + } + ], + "top_tokens": null + }, + "generated_text": "\n# Create a request\nrequest = requests.get" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2271, + "logprob": null, + "text": "Test" + }, + { + "id": 1681, + "logprob": -8.8515625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" + }, + { + "id": 2, + "logprob": -2.9140625, + "special": false, + "text": "#" + }, + { + "id": 4230, + "logprob": -3.1054688, + "special": false, + "text": " Create" + }, + { + "id": 264, + "logprob": -1.0966797, + "special": false, + "text": " a" + }, + { + "id": 1681, + "logprob": -1.6914062, + "special": false, + 
"text": " request" + }, + { + "id": 198, + "logprob": -1.1923828, + "special": false, + "text": "\n" + }, + { + "id": 2035, + "logprob": -1.3193359, + "special": false, + "text": "request" + }, + { + "id": 284, + "logprob": -0.13586426, + "special": false, + "text": " =" + }, + { + "id": 7388, + "logprob": -1.2412109, + "special": false, + "text": " requests" + }, + { + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" + } + ], + "top_tokens": null + }, + "generated_text": "\n# Create a request\nrequest = requests.get" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_santacoder/test_flash_santacoder.json b/integration-tests/models/__snapshots__/test_flash_santacoder/test_flash_santacoder.json new file mode 100644 index 00000000..0293e35a --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_santacoder/test_flash_santacoder.json @@ -0,0 +1,93 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 563, + "logprob": null, + "text": "def" + }, + { + "id": 942, + "logprob": -5.1367188, + "text": " print" + }, + { + "id": 62, + "logprob": -0.24450684, + "text": "_" + }, + { + "id": 7196, + "logprob": -6.9609375, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 1241, + "logprob": -0.9863281, + "special": false, + "text": "():" + }, + { + "id": 258, + "logprob": -0.21447754, + "special": false, + "text": "\n " + }, + { + "id": 942, + "logprob": -0.43701172, + "special": false, + "text": " print" + }, + { + "id": 372, + "logprob": -0.5361328, + "special": false, + "text": "(\"" + }, + { + "id": 7371, + "logprob": -0.44555664, + "special": false, + "text": "Hello" + }, + { + "id": 9956, + "logprob": -1.2412109, + "special": false, + "text": " World" + }, + { + "id": 8657, + "logprob": -0.7583008, + "special": false, + "text": "!\")" + }, + { + "id": 185, + "logprob": -0.76171875, + "special": false, + "text": "\n" + }, + { + "id": 185, + "logprob": -0.20837402, + "special": false, + "text": "\n" + }, + { + "id": 1018, + "logprob": -1.2470703, + "special": false, + "text": "print" + } + ] + }, + "generated_text": "():\n print(\"Hello World!\")\n\nprint" +} diff --git a/integration-tests/models/__snapshots__/test_flash_santacoder/test_flash_santacoder_load.json b/integration-tests/models/__snapshots__/test_flash_santacoder/test_flash_santacoder_load.json new file mode 100644 index 00000000..a03580b3 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_santacoder/test_flash_santacoder_load.json @@ -0,0 +1,374 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 563, + "logprob": null, + "text": "def" + }, + { + "id": 942, + "logprob": -5.1367188, + "text": " print" + }, + { + "id": 62, + "logprob": -0.24450684, + "text": "_" + }, + { + "id": 7196, + "logprob": -6.9609375, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 1241, + "logprob": -0.9863281, + "special": false, + "text": "():" + }, + { + "id": 258, + "logprob": -0.21362305, + "special": false, + "text": "\n " + }, + { + "id": 942, + "logprob": -0.44360352, + "special": false, + "text": " print" + }, + { + "id": 372, + "logprob": -0.54248047, + "special": false, + "text": "(\"" + }, + { + "id": 7371, + "logprob": -0.44555664, + "special": false, + "text": "Hello" + }, + { + "id": 9956, + "logprob": -1.2441406, + "special": false, + "text": " World" + }, + { + "id": 8657, + 
"logprob": -0.75878906, + "special": false, + "text": "!\")" + }, + { + "id": 185, + "logprob": -0.76171875, + "special": false, + "text": "\n" + }, + { + "id": 185, + "logprob": -0.2084961, + "special": false, + "text": "\n" + }, + { + "id": 1018, + "logprob": -1.2460938, + "special": false, + "text": "print" + } + ] + }, + "generated_text": "():\n print(\"Hello World!\")\n\nprint" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 563, + "logprob": null, + "text": "def" + }, + { + "id": 942, + "logprob": -5.1367188, + "text": " print" + }, + { + "id": 62, + "logprob": -0.24450684, + "text": "_" + }, + { + "id": 7196, + "logprob": -6.9609375, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 1241, + "logprob": -0.9863281, + "special": false, + "text": "():" + }, + { + "id": 258, + "logprob": -0.21362305, + "special": false, + "text": "\n " + }, + { + "id": 942, + "logprob": -0.44360352, + "special": false, + "text": " print" + }, + { + "id": 372, + "logprob": -0.54248047, + "special": false, + "text": "(\"" + }, + { + "id": 7371, + "logprob": -0.44555664, + "special": false, + "text": "Hello" + }, + { + "id": 9956, + "logprob": -1.2441406, + "special": false, + "text": " World" + }, + { + "id": 8657, + "logprob": -0.75878906, + "special": false, + "text": "!\")" + }, + { + "id": 185, + "logprob": -0.76171875, + "special": false, + "text": "\n" + }, + { + "id": 185, + "logprob": -0.2084961, + "special": false, + "text": "\n" + }, + { + "id": 1018, + "logprob": -1.2460938, + "special": false, + "text": "print" + } + ] + }, + "generated_text": "():\n print(\"Hello World!\")\n\nprint" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 563, + "logprob": null, + "text": "def" + }, + { + "id": 942, + "logprob": -5.1367188, + "text": " print" + }, + { + "id": 62, + "logprob": -0.24450684, + "text": "_" + }, + { + "id": 7196, + "logprob": -6.9609375, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 1241, + "logprob": -0.9863281, + "special": false, + "text": "():" + }, + { + "id": 258, + "logprob": -0.21362305, + "special": false, + "text": "\n " + }, + { + "id": 942, + "logprob": -0.44360352, + "special": false, + "text": " print" + }, + { + "id": 372, + "logprob": -0.54248047, + "special": false, + "text": "(\"" + }, + { + "id": 7371, + "logprob": -0.44555664, + "special": false, + "text": "Hello" + }, + { + "id": 9956, + "logprob": -1.2441406, + "special": false, + "text": " World" + }, + { + "id": 8657, + "logprob": -0.75878906, + "special": false, + "text": "!\")" + }, + { + "id": 185, + "logprob": -0.76171875, + "special": false, + "text": "\n" + }, + { + "id": 185, + "logprob": -0.2084961, + "special": false, + "text": "\n" + }, + { + "id": 1018, + "logprob": -1.2460938, + "special": false, + "text": "print" + } + ] + }, + "generated_text": "():\n print(\"Hello World!\")\n\nprint" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 563, + "logprob": null, + "text": "def" + }, + { + "id": 942, + "logprob": -5.1367188, + "text": " print" + }, + { + "id": 62, + "logprob": -0.24450684, + "text": "_" + }, + { + "id": 7196, + "logprob": -6.9609375, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 1241, + "logprob": -0.9863281, + "special": false, + "text": "():" + }, + { + "id": 258, + 
"logprob": -0.21362305, + "special": false, + "text": "\n " + }, + { + "id": 942, + "logprob": -0.44360352, + "special": false, + "text": " print" + }, + { + "id": 372, + "logprob": -0.54248047, + "special": false, + "text": "(\"" + }, + { + "id": 7371, + "logprob": -0.44555664, + "special": false, + "text": "Hello" + }, + { + "id": 9956, + "logprob": -1.2441406, + "special": false, + "text": " World" + }, + { + "id": 8657, + "logprob": -0.75878906, + "special": false, + "text": "!\")" + }, + { + "id": 185, + "logprob": -0.76171875, + "special": false, + "text": "\n" + }, + { + "id": 185, + "logprob": -0.2084961, + "special": false, + "text": "\n" + }, + { + "id": 1018, + "logprob": -1.2460938, + "special": false, + "text": "print" + } + ] + }, + "generated_text": "():\n print(\"Hello World!\")\n\nprint" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder.json b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder.json new file mode 100644 index 00000000..8505c1db --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder.json @@ -0,0 +1,93 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 1459, + "logprob": -5.6289062, + "text": " print" + }, + { + "id": 81, + "logprob": -1.6005859, + "text": "_" + }, + { + "id": 7656, + "logprob": -5.9921875, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2262, + "logprob": -0.7705078, + "special": false, + "text": "():" + }, + { + "id": 284, + "logprob": -0.2590332, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": -0.39379883, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.61376953, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": -0.47338867, + "special": false, + "text": "Hello" + }, + { + "id": 10896, + "logprob": -1.5068359, + "special": false, + "text": " World" + }, + { + "id": 657, + "logprob": -0.80810547, + "special": false, + "text": "\")" + }, + { + "id": 203, + "logprob": -0.7397461, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": -0.35229492, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -1.0371094, + "special": false, + "text": "def" + } + ] + }, + "generated_text": "():\n print(\"Hello World\")\n\ndef" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json new file mode 100644 index 00000000..89e02c07 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json @@ -0,0 +1,393 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 60, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 1459, + "logprob": -5.6328125, + "text": " print" + }, + { + "id": 81, + "logprob": -1.6035156, + "text": "_" + }, + { + "id": 7656, + "logprob": -5.9882812, + "text": "hello" + } + ], + "seed": 0, + "tokens": [ + { + "id": 2262, + "logprob": -0.042999268, + "special": false, + "text": "():" + }, + { + "id": 284, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 440, 
+ "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 10896, + "logprob": -0.38549805, + "special": false, + "text": " World" + }, + { + "id": 657, + "logprob": -0.5229492, + "special": false, + "text": "\")" + }, + { + "id": 203, + "logprob": -0.10632324, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -0.20141602, + "special": false, + "text": "def" + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7656, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 711, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 284, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.16027832, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 313, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 636, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 27, + "logprob": 0.0, + "special": false, + "text": ")" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7656, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 381, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 30, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 11442, + "logprob": 0.0, + "special": false, + "text": " age" + }, + { + "id": 711, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 284, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 313, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 636, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 313, + 
"logprob": -0.6328125, + "special": false, + "text": " \"" + }, + { + "id": 313, + "logprob": -1.7011719, + "special": false, + "text": " \"" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 596, + "logprob": 0.0, + "special": false, + "text": " str" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 381, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 490, + "logprob": 0.0, + "special": false, + "text": "))" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + } + ] + }, + "generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_load.json new file mode 100644 index 00000000..0b3ad554 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_load.json @@ -0,0 +1,374 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 1459, + "logprob": -5.6289062, + "text": " print" + }, + { + "id": 81, + "logprob": -1.6005859, + "text": "_" + }, + { + "id": 7656, + "logprob": -5.9921875, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2262, + "logprob": -0.7705078, + "special": false, + "text": "():" + }, + { + "id": 284, + "logprob": -0.2602539, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": -0.39282227, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.6113281, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": -0.4765625, + "special": false, + "text": "Hello" + }, + { + "id": 10896, + "logprob": -1.5068359, + "special": false, + "text": " World" + }, + { + "id": 657, + "logprob": -0.8154297, + "special": false, + "text": "\")" + }, + { + "id": 203, + "logprob": -0.7319336, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": -0.35229492, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -1.0380859, + "special": false, + "text": "def" + } + ] + }, + "generated_text": "():\n print(\"Hello World\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 1459, + "logprob": -5.6289062, + "text": " print" + }, + { + "id": 81, + "logprob": -1.6005859, + "text": "_" + }, + { + "id": 7656, + "logprob": -5.9921875, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2262, + "logprob": -0.7705078, + "special": false, + "text": "():" + }, + { + "id": 284, + "logprob": -0.2602539, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": -0.39282227, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.6113281, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": -0.4765625, + "special": false, + "text": 
"Hello" + }, + { + "id": 10896, + "logprob": -1.5068359, + "special": false, + "text": " World" + }, + { + "id": 657, + "logprob": -0.8154297, + "special": false, + "text": "\")" + }, + { + "id": 203, + "logprob": -0.7319336, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": -0.35229492, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -1.0380859, + "special": false, + "text": "def" + } + ] + }, + "generated_text": "():\n print(\"Hello World\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 1459, + "logprob": -5.6289062, + "text": " print" + }, + { + "id": 81, + "logprob": -1.6005859, + "text": "_" + }, + { + "id": 7656, + "logprob": -5.9921875, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2262, + "logprob": -0.7705078, + "special": false, + "text": "():" + }, + { + "id": 284, + "logprob": -0.2602539, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": -0.39282227, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.6113281, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": -0.4765625, + "special": false, + "text": "Hello" + }, + { + "id": 10896, + "logprob": -1.5068359, + "special": false, + "text": " World" + }, + { + "id": 657, + "logprob": -0.8154297, + "special": false, + "text": "\")" + }, + { + "id": 203, + "logprob": -0.7319336, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": -0.35229492, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -1.0380859, + "special": false, + "text": "def" + } + ] + }, + "generated_text": "():\n print(\"Hello World\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 1459, + "logprob": -5.6289062, + "text": " print" + }, + { + "id": 81, + "logprob": -1.6005859, + "text": "_" + }, + { + "id": 7656, + "logprob": -5.9921875, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2262, + "logprob": -0.7705078, + "special": false, + "text": "():" + }, + { + "id": 284, + "logprob": -0.2602539, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": -0.39282227, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.6113281, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": -0.4765625, + "special": false, + "text": "Hello" + }, + { + "id": 10896, + "logprob": -1.5068359, + "special": false, + "text": " World" + }, + { + "id": 657, + "logprob": -0.8154297, + "special": false, + "text": "\")" + }, + { + "id": 203, + "logprob": -0.7319336, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": -0.35229492, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -1.0380859, + "special": false, + "text": "def" + } + ] + }, + "generated_text": "():\n print(\"Hello World\")\n\ndef" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json new file mode 100644 index 00000000..36a2ff4d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json @@ -0,0 +1,94 @@ +{ + "details": { + "best_of_sequences": null, + 
"finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40844727, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27905273, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6118164, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68652344, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4619141, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.7993164, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.63134766, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23278809, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2294922, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json new file mode 100644 index 00000000..38117272 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json @@ -0,0 +1,394 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 60, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": 0, + "tokens": [ + { + "id": 2284, + "logprob": -0.296875, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.28125, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -0.79248047, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.61816406, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.0619812, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -0.4091797, + "special": false, + "text": "def" + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7670, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 444, + "logprob": -0.21655273, + "special": false, + "text": "name" + }, + { + "id": 45, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 731, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 303, + 
"logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 332, + "logprob": -0.034698486, + "special": false, + "text": " \"" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 655, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 494, + "logprob": -0.20141602, + "special": false, + "text": " +" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 16013, + "logprob": 0.0, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7670, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 400, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 45, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 49, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 11505, + "logprob": 0.0, + "special": false, + "text": " age" + }, + { + "id": 731, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 303, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 655, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 3021, + "logprob": -0.5761719, + "special": false, + "text": " \"," + }, + { + "id": 863, + "logprob": 0.0, + "special": false, + "text": " you" + }, + { + "id": 904, + "logprob": 0.0, + "special": false, + "text": " are" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 615, + "logprob": 0.0, + "special": false, + "text": " str" + }, + { + "id": 45, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 400, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 46, + "logprob": 0.0, + "special": false, + "text": ")" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)" +} diff --git 
a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json new file mode 100644 index 00000000..9e82d4be --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json @@ -0,0 +1,378 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + 
{ + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json new file mode 100644 index 00000000..5e537bb7 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json @@ -0,0 +1,194 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 20, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 3226, + "logprob": -8.5859375, + "text": " ge" + }, + { + "id": 21017, + "logprob": -7.5859375, + "text": "ometric" + }, + { + "id": 81, + "logprob": -0.2668457, + "text": "_" + }, + { + "id": 6009, + "logprob": -1.6416016, + "text": "mean" + }, + { + "id": 26, + "logprob": -0.22705078, + "text": "(" + }, + { + "id": 62, + "logprob": -5.2304688, + "text": "L" + }, + { + "id": 44, + "logprob": -3.0976562, + "text": ":" + }, + { + "id": 1682, + "logprob": -1.1044922, + "text": " List" + }, + { + "id": 77, + "logprob": -0.14294434, + "text": "[" + }, + { + "id": 1808, + "logprob": -0.32299805, + "text": "float" + }, + { + "id": 10794, + "logprob": -2.8164062, + "text": "]):" + } + ], + "seed": null, + "tokens": [ + { + "id": 284, + "logprob": -0.1282959, + "special": 
false, + "text": "\n " + }, + { + "id": 1524, + "logprob": -0.97998047, + "special": false, + "text": " \"\"\"" + }, + { + "id": 284, + "logprob": -0.7006836, + "special": false, + "text": "\n " + }, + { + "id": 14883, + "logprob": -2.1933594, + "special": false, + "text": " Calculate" + }, + { + "id": 322, + "logprob": -0.2697754, + "special": false, + "text": " the" + }, + { + "id": 3226, + "logprob": -0.0836792, + "special": false, + "text": " ge" + }, + { + "id": 21017, + "logprob": -0.018737793, + "special": false, + "text": "ometric" + }, + { + "id": 5651, + "logprob": -0.028640747, + "special": false, + "text": " mean" + }, + { + "id": 432, + "logprob": -0.29467773, + "special": false, + "text": " of" + }, + { + "id": 312, + "logprob": -0.31518555, + "special": false, + "text": " a" + }, + { + "id": 1149, + "logprob": -0.20605469, + "special": false, + "text": " list" + }, + { + "id": 432, + "logprob": -0.23254395, + "special": false, + "text": " of" + }, + { + "id": 7515, + "logprob": -0.4489746, + "special": false, + "text": " numbers" + }, + { + "id": 32, + "logprob": -0.6044922, + "special": false, + "text": "." + }, + { + "id": 446, + "logprob": -0.63964844, + "special": false, + "text": "\n\n " + }, + { + "id": 499, + "logprob": -1.1953125, + "special": false, + "text": " :" + }, + { + "id": 753, + "logprob": -0.03515625, + "special": false, + "text": "param" + }, + { + "id": 498, + "logprob": -0.06311035, + "special": false, + "text": " L" + }, + { + "id": 44, + "logprob": -0.003414154, + "special": false, + "text": ":" + }, + { + "id": 1682, + "logprob": -1.3310547, + "special": false, + "text": " List" + } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json new file mode 100644 index 00000000..bf0f5146 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json @@ -0,0 +1,194 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 20, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 3226, + "logprob": -8.5859375, + "text": " ge" + }, + { + "id": 21017, + "logprob": -7.5898438, + "text": "ometric" + }, + { + "id": 81, + "logprob": -0.26586914, + "text": "_" + }, + { + "id": 6009, + "logprob": -1.6347656, + "text": "mean" + }, + { + "id": 26, + "logprob": -0.22705078, + "text": "(" + }, + { + "id": 62, + "logprob": -5.2382812, + "text": "L" + }, + { + "id": 44, + "logprob": -3.0996094, + "text": ":" + }, + { + "id": 1682, + "logprob": -1.1025391, + "text": " List" + }, + { + "id": 77, + "logprob": -0.14294434, + "text": "[" + }, + { + "id": 1808, + "logprob": -0.32226562, + "text": "float" + }, + { + "id": 10794, + "logprob": -2.8164062, + "text": "]):" + } + ], + "seed": 0, + "tokens": [ + { + "id": 284, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 442, + "logprob": -1.3134766, + "special": false, + "text": " return" + }, + { + "id": 11665, + "logprob": -0.10021973, + "special": false, + "text": " reduce" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 5962, + "logprob": 0.0, + "special": false, + "text": "lambda" + }, + { + "id": 816, + 
"logprob": 0.0, + "special": false, + "text": " x" + }, + { + "id": 30, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 533, + "logprob": 0.0, + "special": false, + "text": " y" + }, + { + "id": 44, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 816, + "logprob": 0.0, + "special": false, + "text": " x" + }, + { + "id": 319, + "logprob": -0.42871094, + "special": false, + "text": " *" + }, + { + "id": 533, + "logprob": 0.0, + "special": false, + "text": " y" + }, + { + "id": 30, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 498, + "logprob": 0.0, + "special": false, + "text": " L" + }, + { + "id": 27, + "logprob": 0.0, + "special": false, + "text": ")" + }, + { + "id": 1115, + "logprob": 0.0, + "special": false, + "text": " **" + }, + { + "id": 308, + "logprob": 0.0, + "special": false, + "text": " (" + }, + { + "id": 35, + "logprob": 0.0, + "special": false, + "text": "1" + }, + { + "id": 32, + "logprob": -0.31323242, + "special": false, + "text": "." + }, + { + "id": 34, + "logprob": 0.0, + "special": false, + "text": "0" + } + ], + "top_tokens": null + }, + "generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json new file mode 100644 index 00000000..46a21ed8 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json @@ -0,0 +1,538 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 3226, + "logprob": -8.5859375, + "text": " ge" + }, + { + "id": 21017, + "logprob": -7.5820312, + "text": "ometric" + }, + { + "id": 81, + "logprob": -0.26708984, + "text": "_" + }, + { + "id": 6009, + "logprob": -1.6386719, + "text": "mean" + }, + { + "id": 26, + "logprob": -0.22717285, + "text": "(" + }, + { + "id": 62, + "logprob": -5.234375, + "text": "L" + }, + { + "id": 44, + "logprob": -3.1015625, + "text": ":" + }, + { + "id": 1682, + "logprob": -1.1083984, + "text": " List" + }, + { + "id": 77, + "logprob": -0.14294434, + "text": "[" + }, + { + "id": 1808, + "logprob": -0.32592773, + "text": "float" + }, + { + "id": 10794, + "logprob": -2.8164062, + "text": "]):" + } + ], + "seed": null, + "tokens": [ + { + "id": 284, + "logprob": -0.12817383, + "special": false, + "text": "\n " + }, + { + "id": 1524, + "logprob": -0.9863281, + "special": false, + "text": " \"\"\"" + }, + { + "id": 284, + "logprob": -0.7011719, + "special": false, + "text": "\n " + }, + { + "id": 14883, + "logprob": -2.2050781, + "special": false, + "text": " Calculate" + }, + { + "id": 322, + "logprob": -0.2668457, + "special": false, + "text": " the" + }, + { + "id": 3226, + "logprob": -0.08465576, + "special": false, + "text": " ge" + }, + { + "id": 21017, + "logprob": -0.019012451, + "special": false, + "text": "ometric" + }, + { + "id": 5651, + "logprob": -0.028625488, + "special": false, + "text": " mean" + }, + { + "id": 432, + "logprob": -0.29418945, + "special": false, + "text": " of" + }, + { + "id": 312, + "logprob": -0.3161621, + "special": false, + "text": " a" + } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" + }, + { + "details": { + "best_of_sequences": null, + 
"finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 3226, + "logprob": -8.5859375, + "text": " ge" + }, + { + "id": 21017, + "logprob": -7.59375, + "text": "ometric" + }, + { + "id": 81, + "logprob": -0.26953125, + "text": "_" + }, + { + "id": 6009, + "logprob": -1.640625, + "text": "mean" + }, + { + "id": 26, + "logprob": -0.22705078, + "text": "(" + }, + { + "id": 62, + "logprob": -5.234375, + "text": "L" + }, + { + "id": 44, + "logprob": -3.1132812, + "text": ":" + }, + { + "id": 1682, + "logprob": -1.1123047, + "text": " List" + }, + { + "id": 77, + "logprob": -0.14294434, + "text": "[" + }, + { + "id": 1808, + "logprob": -0.32299805, + "text": "float" + }, + { + "id": 10794, + "logprob": -2.8164062, + "text": "]):" + } + ], + "seed": null, + "tokens": [ + { + "id": 284, + "logprob": -0.12854004, + "special": false, + "text": "\n " + }, + { + "id": 1524, + "logprob": -0.9897461, + "special": false, + "text": " \"\"\"" + }, + { + "id": 284, + "logprob": -0.69970703, + "special": false, + "text": "\n " + }, + { + "id": 14883, + "logprob": -2.2050781, + "special": false, + "text": " Calculate" + }, + { + "id": 322, + "logprob": -0.2668457, + "special": false, + "text": " the" + }, + { + "id": 3226, + "logprob": -0.08496094, + "special": false, + "text": " ge" + }, + { + "id": 21017, + "logprob": -0.019012451, + "special": false, + "text": "ometric" + }, + { + "id": 5651, + "logprob": -0.029037476, + "special": false, + "text": " mean" + }, + { + "id": 432, + "logprob": -0.2939453, + "special": false, + "text": " of" + }, + { + "id": 312, + "logprob": -0.31591797, + "special": false, + "text": " a" + } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 3226, + "logprob": -8.5859375, + "text": " ge" + }, + { + "id": 21017, + "logprob": -7.5859375, + "text": "ometric" + }, + { + "id": 81, + "logprob": -0.26586914, + "text": "_" + }, + { + "id": 6009, + "logprob": -1.6347656, + "text": "mean" + }, + { + "id": 26, + "logprob": -0.22766113, + "text": "(" + }, + { + "id": 62, + "logprob": -5.2265625, + "text": "L" + }, + { + "id": 44, + "logprob": -3.0976562, + "text": ":" + }, + { + "id": 1682, + "logprob": -1.1025391, + "text": " List" + }, + { + "id": 77, + "logprob": -0.1427002, + "text": "[" + }, + { + "id": 1808, + "logprob": -0.32592773, + "text": "float" + }, + { + "id": 10794, + "logprob": -2.8164062, + "text": "]):" + } + ], + "seed": null, + "tokens": [ + { + "id": 284, + "logprob": -0.13012695, + "special": false, + "text": "\n " + }, + { + "id": 1524, + "logprob": -0.98046875, + "special": false, + "text": " \"\"\"" + }, + { + "id": 284, + "logprob": -0.69921875, + "special": false, + "text": "\n " + }, + { + "id": 14883, + "logprob": -2.1992188, + "special": false, + "text": " Calculate" + }, + { + "id": 322, + "logprob": -0.2668457, + "special": false, + "text": " the" + }, + { + "id": 3226, + "logprob": -0.083496094, + "special": false, + "text": " ge" + }, + { + "id": 21017, + "logprob": -0.01902771, + "special": false, + "text": "ometric" + }, + { + "id": 5651, + "logprob": -0.029006958, + "special": false, + "text": " mean" + }, + { + "id": 432, + "logprob": -0.29248047, + "special": false, + "text": " of" + }, + { + "id": 312, + "logprob": -0.3161621, + 
"special": false, + "text": " a" + } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 3226, + "logprob": -8.5859375, + "text": " ge" + }, + { + "id": 21017, + "logprob": -7.5859375, + "text": "ometric" + }, + { + "id": 81, + "logprob": -0.26904297, + "text": "_" + }, + { + "id": 6009, + "logprob": -1.6386719, + "text": "mean" + }, + { + "id": 26, + "logprob": -0.22705078, + "text": "(" + }, + { + "id": 62, + "logprob": -5.234375, + "text": "L" + }, + { + "id": 44, + "logprob": -3.1132812, + "text": ":" + }, + { + "id": 1682, + "logprob": -1.1074219, + "text": " List" + }, + { + "id": 77, + "logprob": -0.14477539, + "text": "[" + }, + { + "id": 1808, + "logprob": -0.3256836, + "text": "float" + }, + { + "id": 10794, + "logprob": -2.8027344, + "text": "]):" + } + ], + "seed": null, + "tokens": [ + { + "id": 284, + "logprob": -0.12915039, + "special": false, + "text": "\n " + }, + { + "id": 1524, + "logprob": -0.98535156, + "special": false, + "text": " \"\"\"" + }, + { + "id": 284, + "logprob": -0.69921875, + "special": false, + "text": "\n " + }, + { + "id": 14883, + "logprob": -2.2011719, + "special": false, + "text": " Calculate" + }, + { + "id": 322, + "logprob": -0.26708984, + "special": false, + "text": " the" + }, + { + "id": 3226, + "logprob": -0.08502197, + "special": false, + "text": " ge" + }, + { + "id": 21017, + "logprob": -0.019012451, + "special": false, + "text": "ometric" + }, + { + "id": 5651, + "logprob": -0.028625488, + "special": false, + "text": " mean" + }, + { + "id": 432, + "logprob": -0.29589844, + "special": false, + "text": " of" + }, + { + "id": 312, + "logprob": -0.31591797, + "special": false, + "text": " a" + } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" + } +] diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_non_flash_llama_grammar_json.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_non_flash_llama_grammar_json.json new file mode 100644 index 00000000..d7fb620d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_non_flash_llama_grammar_json.json @@ -0,0 +1,274 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 30, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 5235, + "logprob": -10.0625, + "text": "info" + }, + { + "id": 29901, + "logprob": -3.2324219, + "text": ":" + }, + { + "id": 13260, + "logprob": -10.625, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.08276367, + "text": "id" + }, + { + "id": 8753, + "logprob": -7.5273438, + "text": "hol" + }, + { + "id": 17559, + "logprob": -3.8476562, + "text": "tz" + }, + { + "id": 763, + "logprob": -10.140625, + "text": "like" + }, + { + "id": 10697, + "logprob": -10.1953125, + "text": "trees" + }, + { + "id": 322, + "logprob": -2.5742188, + "text": "and" + }, + { + "id": 756, + "logprob": -7.4882812, + "text": "has" + }, + { + "id": 1023, + "logprob": -5.0507812, + "text": "two" + }, + { + "id": 274, + "logprob": -5.3164062, + "text": "c" + }, + { + "id": 1446, + "logprob": -0.6694336, + "text": "ats" + }, + { + "id": 29889, + "logprob": -0.9995117, + "text": "." 
+ }, + { + "id": 29871, + "logprob": -4.2421875, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 6377, + "logprob": -0.14916992, + "special": false, + "text": "{\"" + }, + { + "id": 29888, + "logprob": -0.13598633, + "special": false, + "text": "f" + }, + { + "id": 12935, + "logprob": -0.017669678, + "special": false, + "text": "irs" + }, + { + "id": 29873, + "logprob": -0.00085639954, + "special": false, + "text": "t" + }, + { + "id": 1170, + "logprob": -0.0054016113, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.13549805, + "special": false, + "text": "\":\"" + }, + { + "id": 19504, + "logprob": -0.8852539, + "special": false, + "text": "David" + }, + { + "id": 3284, + "logprob": -0.16394043, + "special": false, + "text": "\",\"" + }, + { + "id": 29882, + "logprob": -0.08862305, + "special": false, + "text": "h" + }, + { + "id": 711, + "logprob": -0.66259766, + "special": false, + "text": "ob" + }, + { + "id": 1609, + "logprob": -5.51939e-05, + "special": false, + "text": "by" + }, + { + "id": 4710, + "logprob": -0.23120117, + "special": false, + "text": "\":\"" + }, + { + "id": 29911, + "logprob": -2.3730469, + "special": false, + "text": "T" + }, + { + "id": 11003, + "logprob": -0.032104492, + "special": false, + "text": "rees" + }, + { + "id": 3284, + "logprob": -0.22021484, + "special": false, + "text": "\",\"" + }, + { + "id": 4230, + "logprob": -0.06726074, + "special": false, + "text": "last" + }, + { + "id": 1170, + "logprob": -0.003501892, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.0045661926, + "special": false, + "text": "\":\"" + }, + { + "id": 29950, + "logprob": -0.12512207, + "special": false, + "text": "H" + }, + { + "id": 14339, + "logprob": -0.009552002, + "special": false, + "text": "olt" + }, + { + "id": 29920, + "logprob": -0.00042438507, + "special": false, + "text": "z" + }, + { + "id": 3284, + "logprob": -0.11651611, + "special": false, + "text": "\",\"" + }, + { + "id": 29876, + "logprob": -0.29736328, + "special": false, + "text": "n" + }, + { + "id": 398, + "logprob": -0.003030777, + "special": false, + "text": "um" + }, + { + "id": 29907, + "logprob": -0.3774414, + "special": false, + "text": "C" + }, + { + "id": 1446, + "logprob": -0.0003130436, + "special": false, + "text": "ats" + }, + { + "id": 1115, + "logprob": -0.0021514893, + "special": false, + "text": "\":" + }, + { + "id": 29906, + "logprob": -0.071899414, + "special": false, + "text": "2" + }, + { + "id": 29913, + "logprob": -0.018997192, + "special": false, + "text": "}" + }, + { + "id": 2, + "logprob": 0.0, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}" +} diff --git a/integration-tests/models/__snapshots__/test_idefics/test_idefics.json b/integration-tests/models/__snapshots__/test_idefics/test_idefics.json new file mode 100644 index 00000000..90fb6dcc --- /dev/null +++ b/integration-tests/models/__snapshots__/test_idefics/test_idefics.json @@ -0,0 +1,168 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4911, + "logprob": -6.9765625, + "text": "User" + }, + { + "id": 29901, + "logprob": -0.0059432983, + "text": ":" + }, + { + "id": 32000, + "logprob": -0.8408203, + "text": "" + }, + { + "id": 32001, + "logprob": -9.906292e-05, + "text": "" + }, + { + "id": 32000, 
+ "logprob": -2.3841858e-07, + "text": "" + }, + { + "id": 1815, + "logprob": -4.1679688, + "text": "Can" + }, + { + "id": 366, + "logprob": -0.014099121, + "text": "you" + }, + { + "id": 2649, + "logprob": -4.4609375, + "text": "tell" + }, + { + "id": 592, + "logprob": -0.29882812, + "text": "me" + }, + { + "id": 263, + "logprob": -4.1445312, + "text": "a" + }, + { + "id": 1407, + "logprob": -9.3828125, + "text": "very" + }, + { + "id": 3273, + "logprob": -1.9736328, + "text": "short" + }, + { + "id": 5828, + "logprob": -0.2800293, + "text": "story" + }, + { + "id": 2729, + "logprob": -3.5625, + "text": "based" + }, + { + "id": 373, + "logprob": -0.0006427765, + "text": "on" + }, + { + "id": 278, + "logprob": -0.13952637, + "text": "the" + }, + { + "id": 1967, + "logprob": -0.068115234, + "text": "image" + }, + { + "id": 29973, + "logprob": -0.16357422, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 32002, + "logprob": -0.0026474, + "special": true, + "text": "" + }, + { + "id": 29871, + "logprob": -8.547306e-05, + "special": false, + "text": " " + }, + { + "id": 13, + "logprob": -1.7881393e-05, + "special": false, + "text": "\n" + }, + { + "id": 7900, + "logprob": -3.0994415e-06, + "special": false, + "text": "Ass" + }, + { + "id": 22137, + "logprob": 0.0, + "special": false, + "text": "istant" + }, + { + "id": 29901, + "logprob": -3.2186508e-06, + "special": false, + "text": ":" + }, + { + "id": 319, + "logprob": -0.92529297, + "special": false, + "text": " A" + }, + { + "id": 696, + "logprob": -1.1269531, + "special": false, + "text": " ro" + }, + { + "id": 15664, + "logprob": -0.00029492378, + "special": false, + "text": "oster" + }, + { + "id": 15028, + "logprob": -1.1855469, + "special": false, + "text": " stands" + } + ] + }, + "generated_text": " \nAssistant: A rooster stands" +} diff --git a/integration-tests/models/__snapshots__/test_idefics/test_idefics_load.json b/integration-tests/models/__snapshots__/test_idefics/test_idefics_load.json new file mode 100644 index 00000000..21d6161b --- /dev/null +++ b/integration-tests/models/__snapshots__/test_idefics/test_idefics_load.json @@ -0,0 +1,674 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4911, + "logprob": -6.9804688, + "text": "User" + }, + { + "id": 29901, + "logprob": -0.006122589, + "text": ":" + }, + { + "id": 32000, + "logprob": -0.8417969, + "text": "" + }, + { + "id": 32001, + "logprob": -9.918213e-05, + "text": "" + }, + { + "id": 32000, + "logprob": -2.3841858e-07, + "text": "" + }, + { + "id": 1815, + "logprob": -4.1679688, + "text": "Can" + }, + { + "id": 366, + "logprob": -0.014091492, + "text": "you" + }, + { + "id": 2649, + "logprob": -4.4726562, + "text": "tell" + }, + { + "id": 592, + "logprob": -0.2998047, + "text": "me" + }, + { + "id": 263, + "logprob": -4.15625, + "text": "a" + }, + { + "id": 1407, + "logprob": -9.3828125, + "text": "very" + }, + { + "id": 3273, + "logprob": -1.9716797, + "text": "short" + }, + { + "id": 5828, + "logprob": -0.27734375, + "text": "story" + }, + { + "id": 2729, + "logprob": -3.5605469, + "text": "based" + }, + { + "id": 373, + "logprob": -0.00064468384, + "text": "on" + }, + { + "id": 278, + "logprob": -0.14160156, + "text": "the" + }, + { + "id": 1967, + "logprob": -0.06915283, + "text": "image" + }, + { + "id": 29973, + "logprob": -0.16381836, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 32002, + "logprob": -0.0026664734, + "special": true, + "text": "" + }, + { + "id": 29871, + "logprob": -8.583069e-05, + "special": false, + "text": " " + }, + { + "id": 13, + "logprob": -1.8119812e-05, + "special": false, + "text": "\n" + }, + { + "id": 7900, + "logprob": -2.9802322e-06, + "special": false, + "text": "Ass" + }, + { + "id": 22137, + "logprob": 0.0, + "special": false, + "text": "istant" + }, + { + "id": 29901, + "logprob": -3.2186508e-06, + "special": false, + "text": ":" + }, + { + "id": 319, + "logprob": -0.9301758, + "special": false, + "text": " A" + }, + { + "id": 696, + "logprob": -1.1279297, + "special": false, + "text": " ro" + }, + { + "id": 15664, + "logprob": -0.0002939701, + "special": false, + "text": "oster" + }, + { + "id": 15028, + "logprob": -1.1865234, + "special": false, + "text": " stands" + } + ] + }, + "generated_text": " \nAssistant: A rooster stands" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4911, + "logprob": -6.9804688, + "text": "User" + }, + { + "id": 29901, + "logprob": -0.006122589, + "text": ":" + }, + { + "id": 32000, + "logprob": -0.8417969, + "text": "" + }, + { + "id": 32001, + "logprob": -9.942055e-05, + "text": "" + }, + { + "id": 32000, + "logprob": -2.3841858e-07, + "text": "" + }, + { + "id": 1815, + "logprob": -4.1679688, + "text": "Can" + }, + { + "id": 366, + "logprob": -0.014091492, + "text": "you" + }, + { + "id": 2649, + "logprob": -4.4726562, + "text": "tell" + }, + { + "id": 592, + "logprob": -0.2998047, + "text": "me" + }, + { + "id": 263, + "logprob": -4.15625, + "text": "a" + }, + { + "id": 1407, + "logprob": -9.3828125, + "text": "very" + }, + { + "id": 3273, + "logprob": -1.9716797, + "text": "short" + }, + { + "id": 5828, + "logprob": -0.27734375, + "text": "story" + }, + { + "id": 2729, + "logprob": -3.5605469, + "text": "based" + }, + { + "id": 373, + "logprob": -0.0006451607, + "text": "on" + }, + { + "id": 278, + "logprob": -0.14160156, + "text": "the" + }, + { + "id": 1967, + "logprob": -0.06915283, + "text": "image" + }, + { + "id": 29973, + "logprob": -0.16381836, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 32002, + "logprob": -0.0026664734, + "special": true, + "text": "" + }, + { + "id": 29871, + "logprob": -8.571148e-05, + "special": false, + "text": " " + }, + { + "id": 13, + "logprob": -1.8119812e-05, + "special": false, + "text": "\n" + }, + { + "id": 7900, + "logprob": -3.0994415e-06, + "special": false, + "text": "Ass" + }, + { + "id": 22137, + "logprob": 0.0, + "special": false, + "text": "istant" + }, + { + "id": 29901, + "logprob": -3.0994415e-06, + "special": false, + "text": ":" + }, + { + "id": 319, + "logprob": -0.9301758, + "special": false, + "text": " A" + }, + { + "id": 696, + "logprob": -1.1279297, + "special": false, + "text": " ro" + }, + { + "id": 15664, + "logprob": -0.0002939701, + "special": false, + "text": "oster" + }, + { + "id": 15028, + "logprob": -1.1865234, + "special": false, + "text": " stands" + } + ] + }, + "generated_text": " \nAssistant: A rooster stands" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4911, + "logprob": -6.9804688, + "text": "User" + }, + { + "id": 29901, + "logprob": -0.006122589, + "text": ":" + }, + { + "id": 32000, + "logprob": -0.8417969, + "text": "" + }, + { + "id": 32001, + "logprob": -9.918213e-05, + "text": "" + }, + { + "id": 32000, + "logprob": -2.3841858e-07, + "text": "" + }, + { + "id": 1815, + "logprob": -4.1679688, + "text": "Can" + }, + { + "id": 366, + "logprob": -0.014091492, + "text": "you" + }, + { + "id": 2649, + "logprob": -4.4726562, + "text": "tell" + }, + { + "id": 592, + "logprob": -0.2998047, + "text": "me" + }, + { + "id": 263, + "logprob": -4.15625, + "text": "a" + }, + { + "id": 1407, + "logprob": -9.3828125, + "text": "very" + }, + { + "id": 3273, + "logprob": -1.9716797, + "text": "short" + }, + { + "id": 5828, + "logprob": -0.27734375, + "text": "story" + }, + { + "id": 2729, + "logprob": -3.5605469, + "text": "based" + }, + { + "id": 373, + "logprob": -0.00064468384, + "text": "on" + }, + { + "id": 278, + "logprob": -0.14160156, + "text": "the" + }, + { + "id": 1967, + "logprob": -0.06915283, + "text": "image" + }, + { + "id": 29973, + "logprob": -0.16381836, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 32002, + "logprob": -0.0026664734, + "special": true, + "text": "" + }, + { + "id": 29871, + "logprob": -8.59499e-05, + "special": false, + "text": " " + }, + { + "id": 13, + "logprob": -1.8119812e-05, + "special": false, + "text": "\n" + }, + { + "id": 7900, + "logprob": -3.0994415e-06, + "special": false, + "text": "Ass" + }, + { + "id": 22137, + "logprob": 0.0, + "special": false, + "text": "istant" + }, + { + "id": 29901, + "logprob": -3.0994415e-06, + "special": false, + "text": ":" + }, + { + "id": 319, + "logprob": -0.9301758, + "special": false, + "text": " A" + }, + { + "id": 696, + "logprob": -1.1279297, + "special": false, + "text": " ro" + }, + { + "id": 15664, + "logprob": -0.0002939701, + "special": false, + "text": "oster" + }, + { + "id": 15028, + "logprob": -1.1865234, + "special": false, + "text": " stands" + } + ] + }, + "generated_text": " \nAssistant: A rooster stands" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4911, + "logprob": -6.9804688, + "text": "User" + }, + { + "id": 29901, + "logprob": -0.006122589, + "text": ":" + }, + { + "id": 32000, + "logprob": -0.8417969, + "text": "" + }, + { + "id": 32001, + "logprob": -9.942055e-05, + "text": "" + }, + { + "id": 32000, + "logprob": -2.3841858e-07, + "text": "" + }, + { + "id": 1815, + "logprob": -4.1679688, + "text": "Can" + }, + { + "id": 366, + "logprob": -0.014091492, + "text": "you" + }, + { + "id": 2649, + "logprob": -4.4726562, + "text": "tell" + }, + { + "id": 592, + "logprob": -0.2998047, + "text": "me" + }, + { + "id": 263, + "logprob": -4.15625, + "text": "a" + }, + { + "id": 1407, + "logprob": -9.3828125, + "text": "very" + }, + { + "id": 3273, + "logprob": -1.9716797, + "text": "short" + }, + { + "id": 5828, + "logprob": -0.27734375, + "text": "story" + }, + { + "id": 2729, + "logprob": -3.5605469, + "text": "based" + }, + { + "id": 373, + "logprob": -0.0006451607, + "text": "on" + }, + { + "id": 278, + "logprob": -0.14160156, + "text": "the" + }, + { + "id": 1967, + "logprob": -0.06915283, + "text": "image" + }, + { + "id": 29973, + "logprob": -0.16381836, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 32002, + "logprob": -0.0026664734, + "special": true, + "text": "" + }, + { + "id": 29871, + "logprob": -8.571148e-05, + "special": false, + "text": " " + }, + { + "id": 13, + "logprob": -1.8119812e-05, + "special": false, + "text": "\n" + }, + { + "id": 7900, + "logprob": -3.0994415e-06, + "special": false, + "text": "Ass" + }, + { + "id": 22137, + "logprob": 0.0, + "special": false, + "text": "istant" + }, + { + "id": 29901, + "logprob": -3.0994415e-06, + "special": false, + "text": ":" + }, + { + "id": 319, + "logprob": -0.9301758, + "special": false, + "text": " A" + }, + { + "id": 696, + "logprob": -1.1279297, + "special": false, + "text": " ro" + }, + { + "id": 15664, + "logprob": -0.0002939701, + "special": false, + "text": "oster" + }, + { + "id": 15028, + "logprob": -1.1865234, + "special": false, + "text": " stands" + } + ] + }, + "generated_text": " \nAssistant: A rooster stands" + } +] diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json new file mode 100644 index 00000000..45601505 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -8.5625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.78125, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 288, + "logprob": -0.2854004, + "special": false, + "text": "ing" + }, + { + "id": 264, + "logprob": -0.37573242, + "special": false, + "text": " a" + }, + { + "id": 633, + "logprob": -0.09301758, + "special": false, + "text": " new" + }, + { + "id": 4480, + "logprob": -0.3322754, + "special": false, + "text": " feature" + }, + { + "id": 297, + "logprob": -0.8510742, + "special": false, + "text": " in" + }, + { + "id": 272, + "logprob": -0.13464355, + "special": false, + "text": " the" + }, + { + "id": 2039, + "logprob": 0.0, + "special": false, + "text": " game" + }, + { + "id": 28723, + "logprob": -0.89990234, + "special": false, + "text": "." 
+ }, + { + "id": 13, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": 0.0, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "Test requesting a new feature in the game.\n\n" +} diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_load.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_load.json new file mode 100644 index 00000000..4bc90896 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_load.json @@ -0,0 +1,7018 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1247, + "logprob": -5.2421875, + "text": "User" + }, + { + "id": 28747, + "logprob": -6.9570312, + "text": ":" + }, + { + "id": 32000, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -23.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -22.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -23.625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.5, + "text": "" + }, + { + "id": 32001, + "logprob": -19.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -23.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -14.8828125, + "text": "" + }, + { + 
"id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -23.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -23.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -23.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.75, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.25, + "text": "" + }, + { + "id": 32001, + "logprob": -17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.5, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4140625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.109375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.7265625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.265625, + "text": "" + }, + { + "id": 32001, 
+ "logprob": -16.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -14.2421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -14.8515625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.5, + "text": "" + }, + { + "id": 32001, + "logprob": -18.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.109375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.25, + "text": "" + }, + { + "id": 32001, + "logprob": -20.25, + "text": "" + }, + { + "id": 32001, + "logprob": -20.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.5, + "text": "" + }, + { + "id": 32001, + "logprob": -20.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": 
-16.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.9453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -22.625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.75, + "text": "" + }, + { + "id": 32001, + "logprob": -21.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.71875, + "text": "" + }, + { + "id": 32001, + "logprob": 
-19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.75, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.109375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.75, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8828125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.1171875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0, + "text": "" + }, + { + "id": 32001, + "logprob": -20.75, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + 
"text": "" + }, + { + "id": 32001, + "logprob": -21.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.5, + "text": "" + }, + { + "id": 32001, + "logprob": -17.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -22.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.75, + "text": "" + }, + { + "id": 32001, + "logprob": -16.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4296875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -13.6328125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4140625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -2.0429688, + "text": "" + }, + { + "id": 12018, + "logprob": -12.03125, + "text": "Write" + }, + { + "id": 528, + "logprob": -10.25, + "text": "me" + }, + { + "id": 264, + "logprob": -0.10437012, + "text": "a" + }, + { + "id": 2485, + "logprob": -4.5742188, + "text": "short" + }, + { + "id": 2838, + "logprob": -0.2277832, + "text": "story" + }, + { + "id": 32002, + "logprob": -10.84375, + "text": "" + }, + { + "id": 259, + "logprob": -20.1875, + "text": " " + }, + { + "id": 13, + "logprob": -8.7578125, + "text": "\n" + }, + { + "id": 7226, + "logprob": -10.421875, + "text": "Ass" + }, + { + "id": 11143, + "logprob": -13.640625, + "text": "istant" + }, + { + "id": 28747, + "logprob": -0.005619049, + "text": ":" + } + ], + "seed": 
null, + "tokens": [ + { + "id": 330, + "logprob": -0.12939453, + "special": false, + "text": " A" + }, + { + "id": 13088, + "logprob": -0.6660156, + "special": false, + "text": " chicken" + }, + { + "id": 349, + "logprob": -0.29638672, + "special": false, + "text": " is" + }, + { + "id": 6398, + "logprob": -0.05960083, + "special": false, + "text": " sitting" + }, + { + "id": 356, + "logprob": -0.26953125, + "special": false, + "text": " on" + }, + { + "id": 264, + "logprob": -0.1427002, + "special": false, + "text": " a" + }, + { + "id": 17972, + "logprob": -0.040649414, + "special": false, + "text": " pile" + }, + { + "id": 302, + "logprob": -0.0002708435, + "special": false, + "text": " of" + }, + { + "id": 2445, + "logprob": -0.09429932, + "special": false, + "text": " money" + }, + { + "id": 28723, + "logprob": -0.006931305, + "special": false, + "text": "." + } + ], + "top_tokens": null + }, + "generated_text": " A chicken is sitting on a pile of money." + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1247, + "logprob": -5.234375, + "text": "User" + }, + { + "id": 28747, + "logprob": -6.9648438, + "text": ":" + }, + { + "id": 32000, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -23.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -22.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -23.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.5, + "text": "" + }, + { + "id": 32001, + "logprob": -19.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -23.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.828125, + "text": "" + }, + { + "id": 
32001, + "logprob": -17.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -14.8828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -23.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -23.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -23.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.75, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.25, + "text": "" + }, + { + "id": 32001, + "logprob": -17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.5, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4140625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.109375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.7265625, + "text": "" + }, + { + "id": 32001, + "logprob": 
-18.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -14.2421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -14.8515625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.5, + "text": "" + }, + { + "id": 32001, + "logprob": -18.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.109375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.25, + "text": "" + }, + { + "id": 32001, + "logprob": -20.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.359375, + "text": "" + }, + { + "id": 32001, + "logprob": 
-19.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.5, + "text": "" + }, + { + "id": 32001, + "logprob": -20.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.9453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -22.625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.75, + "text": "" + }, + { + "id": 32001, + "logprob": -21.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.78125, + "text": 
"" + }, + { + "id": 32001, + "logprob": -19.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.75, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.109375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.75, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.1171875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + 
"text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0, + "text": "" + }, + { + "id": 32001, + "logprob": -20.75, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.5, + "text": "" + }, + { + "id": 32001, + "logprob": -17.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -22.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.75, + "text": "" + }, + { + "id": 32001, + "logprob": -16.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -22.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4296875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -2.0429688, + "text": "" + }, + { + "id": 12018, + "logprob": -12.03125, + "text": "Write" + }, + { + "id": 528, + "logprob": -10.2578125, + "text": "me" + }, + { + "id": 264, + "logprob": -0.10418701, + "text": "a" + }, + { + "id": 2485, + "logprob": -4.5664062, + "text": "short" + }, + { + "id": 2838, + "logprob": -0.22741699, + "text": "story" + }, + { + "id": 32002, + "logprob": 
-10.8515625, + "text": "" + }, + { + "id": 259, + "logprob": -20.203125, + "text": " " + }, + { + "id": 13, + "logprob": -8.7421875, + "text": "\n" + }, + { + "id": 7226, + "logprob": -10.4140625, + "text": "Ass" + }, + { + "id": 11143, + "logprob": -13.6328125, + "text": "istant" + }, + { + "id": 28747, + "logprob": -0.005580902, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 330, + "logprob": -0.1295166, + "special": false, + "text": " A" + }, + { + "id": 13088, + "logprob": -0.6669922, + "special": false, + "text": " chicken" + }, + { + "id": 349, + "logprob": -0.29711914, + "special": false, + "text": " is" + }, + { + "id": 6398, + "logprob": -0.059936523, + "special": false, + "text": " sitting" + }, + { + "id": 356, + "logprob": -0.27124023, + "special": false, + "text": " on" + }, + { + "id": 264, + "logprob": -0.140625, + "special": false, + "text": " a" + }, + { + "id": 17972, + "logprob": -0.04058838, + "special": false, + "text": " pile" + }, + { + "id": 302, + "logprob": -0.00027012825, + "special": false, + "text": " of" + }, + { + "id": 2445, + "logprob": -0.09503174, + "special": false, + "text": " money" + }, + { + "id": 28723, + "logprob": -0.006942749, + "special": false, + "text": "." + } + ], + "top_tokens": null + }, + "generated_text": " A chicken is sitting on a pile of money." + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1247, + "logprob": -5.2460938, + "text": "User" + }, + { + "id": 28747, + "logprob": -6.9570312, + "text": ":" + }, + { + "id": 32000, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -23.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -22.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -23.625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.5, + "text": "" + }, + { + "id": 32001, + "logprob": -19.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -23.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.359375, + "text": "" + }, + { + "id": 32001, 
+ "logprob": -16.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -14.8828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -23.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -23.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -23.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.75, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.25, + "text": "" + }, + { + "id": 32001, + "logprob": -17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8359375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.78125, + "text": "" + }, + { + "id": 32001, + "logprob": 
-17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.5, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4140625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.109375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.7265625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -14.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -14.8515625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32001, + "logprob": -17.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.5, + "text": "" + }, + { + "id": 32001, + "logprob": -18.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.109375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.28125, + 
"text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.25, + "text": "" + }, + { + "id": 32001, + "logprob": -20.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.9453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -22.625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.671875, + 
"text": "" + }, + { + "id": 32001, + "logprob": -18.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.75, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.109375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.75, + "text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.375, + 
"text": "" + }, + { + "id": 32001, + "logprob": -16.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.1171875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0, + "text": "" + }, + { + "id": 32001, + "logprob": -20.75, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.5, + "text": "" + }, + { + "id": 32001, + "logprob": -17.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -22.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.75, + "text": "" + }, + { + "id": 32001, + "logprob": -16.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4296875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0, + "text": "" + }, + { + "id": 32001, + "logprob": -21.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4140625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -2.0429688, + 
"text": "" + }, + { + "id": 12018, + "logprob": -12.0390625, + "text": "Write" + }, + { + "id": 528, + "logprob": -10.25, + "text": "me" + }, + { + "id": 264, + "logprob": -0.10443115, + "text": "a" + }, + { + "id": 2485, + "logprob": -4.5742188, + "text": "short" + }, + { + "id": 2838, + "logprob": -0.22729492, + "text": "story" + }, + { + "id": 32002, + "logprob": -10.84375, + "text": "" + }, + { + "id": 259, + "logprob": -20.1875, + "text": " " + }, + { + "id": 13, + "logprob": -8.7578125, + "text": "\n" + }, + { + "id": 7226, + "logprob": -10.4140625, + "text": "Ass" + }, + { + "id": 11143, + "logprob": -13.6328125, + "text": "istant" + }, + { + "id": 28747, + "logprob": -0.0056533813, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 330, + "logprob": -0.12963867, + "special": false, + "text": " A" + }, + { + "id": 13088, + "logprob": -0.6660156, + "special": false, + "text": " chicken" + }, + { + "id": 349, + "logprob": -0.29516602, + "special": false, + "text": " is" + }, + { + "id": 6398, + "logprob": -0.060028076, + "special": false, + "text": " sitting" + }, + { + "id": 356, + "logprob": -0.27075195, + "special": false, + "text": " on" + }, + { + "id": 264, + "logprob": -0.1427002, + "special": false, + "text": " a" + }, + { + "id": 17972, + "logprob": -0.04067993, + "special": false, + "text": " pile" + }, + { + "id": 302, + "logprob": -0.000269413, + "special": false, + "text": " of" + }, + { + "id": 2445, + "logprob": -0.09387207, + "special": false, + "text": " money" + }, + { + "id": 28723, + "logprob": -0.0069236755, + "special": false, + "text": "." + } + ], + "top_tokens": null + }, + "generated_text": " A chicken is sitting on a pile of money." + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1247, + "logprob": -5.2421875, + "text": "User" + }, + { + "id": 28747, + "logprob": -6.9570312, + "text": ":" + }, + { + "id": 32000, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.25, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -23.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -22.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -23.625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.84375, + "text": "" + }, + { + "id": 
32001, + "logprob": -16.5, + "text": "" + }, + { + "id": 32001, + "logprob": -19.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -23.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.328125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -14.8828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -23.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -23.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -23.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.75, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.34375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.0, + "text": "" + }, + { + "id": 32001, + "logprob": -18.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.25, + "text": "" + }, + { + "id": 32001, + "logprob": -17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8359375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.453125, + "text": "" + }, + { + "id": 32001, + "logprob": 
-19.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.5, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4140625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.109375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.7265625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -14.2421875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.578125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.25, + "text": "" + }, + { + "id": 32001, + "logprob": -17.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -14.8515625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.71875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.5, + "text": "" + }, + { + "id": 32001, + "logprob": -18.296875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.8125, + "text": "" + }, + { + "id": 32001, + "logprob": 
-16.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.25, + "text": "" + }, + { + "id": 32001, + "logprob": -20.25, + "text": "" + }, + { + "id": 32001, + "logprob": -20.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.609375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.5, + "text": "" + }, + { + "id": 32001, + "logprob": -20.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.265625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.9453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.03125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.703125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.203125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -22.625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + 
"text": "" + }, + { + "id": 32001, + "logprob": -16.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.5625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.15625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.8125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.96875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.78125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.828125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.75, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.953125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.984375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.046875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.6875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -19.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -17.109375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.21875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.796875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.8125, + 
"text": "" + }, + { + "id": 32001, + "logprob": -19.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.09375, + "text": "" + }, + { + "id": 32001, + "logprob": -18.75, + "text": "" + }, + { + "id": 32001, + "logprob": -18.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -20.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.1171875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.0, + "text": "" + }, + { + "id": 32001, + "logprob": -20.75, + "text": "" + }, + { + "id": 32001, + "logprob": -16.25, + "text": "" + }, + { + "id": 32001, + "logprob": -19.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -21.59375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.9375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.671875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.921875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.5, + "text": "" + }, + { + "id": 32001, + "logprob": -17.90625, + "text": "" + }, + { + "id": 32001, + "logprob": -22.1875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.734375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.40625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -22.28125, + "text": "" + }, + { + "id": 32001, + "logprob": -18.515625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4296875, + "text": "" + }, + { + "id": 32001, + "logprob": -19.765625, + "text": "" + }, + { + "id": 32001, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.46875, + "text": "" + }, + { + "id": 32001, + "logprob": -18.875, + "text": "" + }, + { + "id": 32001, + "logprob": -20.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32001, + "logprob": -16.4375, + "text": "" + }, + { + "id": 32001, + "logprob": -21.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.234375, + "text": "" + }, + { + "id": 32001, + "logprob": -17.140625, + "text": "" + }, + { + "id": 32001, + "logprob": -21.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -20.015625, + "text": "" + }, + { + "id": 32001, + "logprob": -18.84375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.421875, + "text": "" + }, + { + "id": 32001, + "logprob": -16.890625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32001, + "logprob": -15.4140625, + "text": "" + }, + { + "id": 32001, + "logprob": -17.546875, + "text": "" + }, + { + "id": 32001, + "logprob": -21.859375, + "text": "" + }, + { + "id": 32001, + "logprob": -15.65625, + "text": "" + }, + { + "id": 32001, + "logprob": -16.484375, + "text": "" + }, + { + "id": 32001, + "logprob": -16.359375, + "text": "" + }, + { + "id": 32001, + "logprob": -19.9375, + 
"text": "" + }, + { + "id": 32001, + "logprob": -17.875, + "text": "" + }, + { + "id": 32001, + "logprob": -17.453125, + "text": "" + }, + { + "id": 32001, + "logprob": -20.390625, + "text": "" + }, + { + "id": 32001, + "logprob": -19.171875, + "text": "" + }, + { + "id": 32001, + "logprob": -15.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -2.0429688, + "text": "" + }, + { + "id": 12018, + "logprob": -12.03125, + "text": "Write" + }, + { + "id": 528, + "logprob": -10.25, + "text": "me" + }, + { + "id": 264, + "logprob": -0.10437012, + "text": "a" + }, + { + "id": 2485, + "logprob": -4.578125, + "text": "short" + }, + { + "id": 2838, + "logprob": -0.22924805, + "text": "story" + }, + { + "id": 32002, + "logprob": -10.84375, + "text": "" + }, + { + "id": 259, + "logprob": -20.171875, + "text": " " + }, + { + "id": 13, + "logprob": -8.765625, + "text": "\n" + }, + { + "id": 7226, + "logprob": -10.4140625, + "text": "Ass" + }, + { + "id": 11143, + "logprob": -13.640625, + "text": "istant" + }, + { + "id": 28747, + "logprob": -0.005744934, + "text": ":" + } + ], + "seed": null, + "tokens": [ + { + "id": 330, + "logprob": -0.12976074, + "special": false, + "text": " A" + }, + { + "id": 13088, + "logprob": -0.66308594, + "special": false, + "text": " chicken" + }, + { + "id": 349, + "logprob": -0.29541016, + "special": false, + "text": " is" + }, + { + "id": 6398, + "logprob": -0.05996704, + "special": false, + "text": " sitting" + }, + { + "id": 356, + "logprob": -0.27075195, + "special": false, + "text": " on" + }, + { + "id": 264, + "logprob": -0.14160156, + "special": false, + "text": " a" + }, + { + "id": 17972, + "logprob": -0.040863037, + "special": false, + "text": " pile" + }, + { + "id": 302, + "logprob": -0.00027036667, + "special": false, + "text": " of" + }, + { + "id": 2445, + "logprob": -0.093322754, + "special": false, + "text": " money" + }, + { + "id": 28723, + "logprob": -0.006931305, + "special": false, + "text": "." + } + ], + "top_tokens": null + }, + "generated_text": " A chicken is sitting on a pile of money." + } +] diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_simple.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_simple.json new file mode 100644 index 00000000..a3b18d0a --- /dev/null +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_simple.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 330, + "logprob": -0.13000488, + "special": false, + "text": " A" + }, + { + "id": 13088, + "logprob": -0.6713867, + "special": false, + "text": " chicken" + }, + { + "id": 349, + "logprob": -0.2980957, + "special": false, + "text": " is" + }, + { + "id": 6398, + "logprob": -0.060638428, + "special": false, + "text": " sitting" + }, + { + "id": 356, + "logprob": -0.27319336, + "special": false, + "text": " on" + }, + { + "id": 264, + "logprob": -0.140625, + "special": false, + "text": " a" + }, + { + "id": 17972, + "logprob": -0.040405273, + "special": false, + "text": " pile" + }, + { + "id": 302, + "logprob": -0.0002708435, + "special": false, + "text": " of" + }, + { + "id": 2445, + "logprob": -0.095336914, + "special": false, + "text": " money" + }, + { + "id": 28723, + "logprob": -0.0068359375, + "special": false, + "text": "." 
+ } + ], + "top_tokens": null + }, + "generated_text": " A chicken is sitting on a pile of money." +} diff --git a/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_all_params.json b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_all_params.json new file mode 100644 index 00000000..e9d3e5ef --- /dev/null +++ b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_all_params.json @@ -0,0 +1,65 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "stop_sequence", + "generated_tokens": 6, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -10.5, + "text": "Test" + }, + { + "id": 2159, + "logprob": -12.140625, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 13, + "logprob": -1.0654297, + "special": false, + "text": "\n" + }, + { + "id": 1014, + "logprob": -2.7460938, + "special": false, + "text": "The" + }, + { + "id": 6032, + "logprob": -1.359375, + "special": false, + "text": " purpose" + }, + { + "id": 302, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 456, + "logprob": 0.0, + "special": false, + "text": " this" + }, + { + "id": 1369, + "logprob": -0.40063477, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "Test request\nThe purpose of this test" +} diff --git a/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_load.json b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_load.json new file mode 100644 index 00000000..2007c0f2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_load.json @@ -0,0 +1,59178 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1247, + "logprob": -2.3886719, + "text": "User" + }, + { + "id": 28747, + "logprob": -12.328125, + "text": ":" + }, + { + "id": 32000, + "logprob": -10.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2421875, + "text": "" 
+ }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -17.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -18.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -9.875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6796875, + "text": "" + }, + { + "id": 
32000, + "logprob": -10.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -9.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": 
-15.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.359375, + "text": "" + 
}, + { + "id": 32000, + "logprob": -10.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.921875, + "text": "" + }, + { + 
"id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 
32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 
32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + 
"id": 32000, + "logprob": -15.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -9.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -9.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": 
-12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, 
+ "text": "" + }, + { + "id": 32000, + "logprob": -16.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -18.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2890625, + "text": "" + 
}, + { + "id": 32000, + "logprob": -13.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.75, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 
32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": 
-12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 
32000, + "logprob": -14.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9609375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": 
-12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + 
{ + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 
32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.703125, + "text": "" + }, + { + "id": 32000, + "logprob": 
-11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + 
{ + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": 
-13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + 
}, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + 
"logprob": -15.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -13.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": 
-12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0625, + "text": "" + 
}, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + 
"id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -9.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": 
-12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -18.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2421875, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 
32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + 
"logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -17.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + 
"text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.34375, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + 
"logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.25, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": 
-12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -17.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -9.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" 
+ }, + { + "id": 32000, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + 
"id": 32000, + "logprob": -14.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.25, + "text": "" + }, + { + "id": 32000, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -17.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + 
"text": "" + }, + { + "id": 32000, + "logprob": -14.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { 
+ "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": 
-11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -17.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" 
+ }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -18.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 
32000, + "logprob": -16.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -18.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, 
+ "text": "" + }, + { + "id": 32000, + "logprob": -14.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + 
"id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": 
-12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -17.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 2418, + "logprob": -19.0625, + "text": "Can" + }, + { + "id": 368, + "logprob": -0.19726562, + "text": "you" + }, + { + "id": 1912, + "logprob": -1.4990234, + "text": "tell" + }, + { + "id": 528, + "logprob": -0.31152344, + "text": "me" + }, + { + "id": 264, + "logprob": -2.6367188, + "text": "a" + }, + { + "id": 1215, + "logprob": -9.1015625, + "text": "very" + }, + { + "id": 2485, + "logprob": -0.9941406, + "text": "short" + }, + { + "id": 2838, + "logprob": -0.46118164, + "text": "story" + }, + { + "id": 2818, + "logprob": -3.3183594, + "text": "based" + }, + { + "id": 356, + "logprob": -0.029129028, + "text": "on" + }, + { + "id": 272, + "logprob": -0.9902344, + "text": "the" + }, + { + "id": 3469, + "logprob": -0.29052734, + "text": "image" + }, + { + "id": 28804, + "logprob": -0.43188477, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.0076828003, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.20092773, + "special": false, + "text": "\n" + }, + { + "id": 16114, + "logprob": -1.2587891, + "special": false, + "text": "Once" + }, + { + "id": 3714, + "logprob": -0.20861816, + "special": false, + "text": " upon" + }, + { + "id": 264, + "logprob": -0.0017719269, + "special": false, + "text": " a" + }, + { + "id": 727, + "logprob": -0.011909485, + "special": false, + "text": " time" + }, + { + "id": 28725, + "logprob": -0.17529297, + "special": false, + "text": "," + }, + { + "id": 736, + "logprob": -0.9082031, + "special": false, + "text": " there" + }, + { + "id": 403, + "logprob": -0.057525635, + "special": false, + "text": " was" + }, + { + "id": 264, + "logprob": -0.009651184, + "special": false, + "text": " a" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nOnce upon a time, there was a" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1247, + "logprob": -2.3886719, + "text": "User" + }, + { + "id": 28747, + "logprob": -12.328125, + "text": ":" + }, + { + "id": 32000, + "logprob": -10.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -17.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -18.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -9.875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": 
-10.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.71875, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -9.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + 
{ + "id": 32000, + "logprob": -10.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.203125, + "text": "" + }, + { + "id": 32000, + "logprob": 
-12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": 
-11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": 
-13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-14.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, 
+ { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + 
"logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-15.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -9.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1171875, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -9.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + 
}, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { 
+ "id": 32000, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -18.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + 
"logprob": -13.75, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + 
"id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 
32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.90625, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.640625, + "text": "" + }, + 
{ + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 
32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + 
}, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, 
+ "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 
32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3359375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -13.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 
32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + 
"logprob": -14.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": 
-12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -9.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -18.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + 
"text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -17.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + 
{ + "id": 32000, + "logprob": -11.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.265625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-13.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-10.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + 
"text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + { + "id": 32000, 
+ "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.25, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -17.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -9.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.25, + "text": "" + }, + { + "id": 32000, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { 
+ "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -17.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-17.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" 
+ }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -17.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -18.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.796875, + "text": "" 
+ }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -18.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + 
"logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6171875, 
+ "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.578125, + "text": "" + }, + { + 
"id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -17.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 2418, + "logprob": -19.0625, + "text": "Can" + }, + { + "id": 368, + "logprob": -0.19726562, + "text": "you" + }, + { + "id": 1912, + "logprob": -1.4990234, + "text": "tell" + }, + { + "id": 528, + "logprob": -0.31152344, + "text": "me" + }, + { + "id": 264, + "logprob": -2.6367188, + "text": "a" + }, + { + "id": 1215, + "logprob": -9.1015625, + "text": "very" + }, + { + "id": 2485, + "logprob": -0.9941406, + "text": "short" + }, + { + "id": 2838, + "logprob": -0.46118164, + "text": "story" + }, + { + "id": 2818, + "logprob": -3.3183594, + "text": "based" + }, + { + "id": 356, + "logprob": -0.029129028, + "text": "on" + }, + { + "id": 272, + "logprob": -0.9902344, + "text": "the" + }, + { + "id": 3469, + "logprob": -0.29052734, + "text": "image" + }, + { + "id": 28804, + "logprob": -0.43188477, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.0076828003, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.19958496, + "special": false, + "text": "\n" + }, + { + "id": 16114, + "logprob": -1.2587891, + "special": false, + "text": "Once" + }, + { + "id": 3714, + "logprob": -0.20861816, + "special": false, + "text": " upon" + }, + { + "id": 264, + "logprob": -0.0017719269, + "special": false, + "text": " a" + }, + { + "id": 727, + "logprob": -0.011749268, + "special": false, + "text": " time" + }, + { + "id": 28725, + "logprob": -0.17529297, + "special": false, + "text": "," + }, + { + "id": 736, + "logprob": -0.9086914, + "special": false, + "text": " there" + }, + { + "id": 403, + "logprob": -0.056732178, + "special": false, + "text": " was" + }, + { + "id": 264, + "logprob": -0.00970459, + "special": false, + "text": " a" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nOnce upon a time, there was a" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1247, + "logprob": -2.3886719, + "text": "User" + }, + { + "id": 28747, + "logprob": -12.328125, + "text": ":" + }, + { + "id": 32000, + "logprob": -10.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -17.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -18.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -9.875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": 
-10.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.71875, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -9.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + 
{ + "id": 32000, + "logprob": -10.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.203125, + "text": "" + }, + { + "id": 32000, + "logprob": 
-12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": 
-11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": 
-13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-14.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, 
+ { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + 
"logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-15.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -9.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1171875, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -9.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + 
}, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { 
+ "id": 32000, + "logprob": -17.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -18.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + 
"logprob": -13.75, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + 
"id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 
32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.90625, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.640625, + "text": "" + }, + 
{ + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 
32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + 
}, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, 
+ "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 
32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3359375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -13.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 
32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + 
"logprob": -14.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": 
-12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -9.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -18.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + 
"text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -17.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + 
{ + "id": 32000, + "logprob": -11.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.265625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-13.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-10.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + 
"text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + { + "id": 32000, 
+ "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.25, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -17.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -9.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.25, + "text": "" + }, + { + "id": 32000, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { 
+ "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -17.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-17.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" 
+ }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -17.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -18.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.796875, + "text": "" 
+ }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -18.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + 
"logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6171875, 
+ "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.578125, + "text": "" + }, + { + 
"id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -17.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 2418, + "logprob": -19.0625, + "text": "Can" + }, + { + "id": 368, + "logprob": -0.19726562, + "text": "you" + }, + { + "id": 1912, + "logprob": -1.4990234, + "text": "tell" + }, + { + "id": 528, + "logprob": -0.31152344, + "text": "me" + }, + { + "id": 264, + "logprob": -2.6367188, + "text": "a" + }, + { + "id": 1215, + "logprob": -9.1015625, + "text": "very" + }, + { + "id": 2485, + "logprob": -0.9941406, + "text": "short" + }, + { + "id": 2838, + "logprob": -0.46118164, + "text": "story" + }, + { + "id": 2818, + "logprob": -3.3183594, + "text": "based" + }, + { + "id": 356, + "logprob": -0.029129028, + "text": "on" + }, + { + "id": 272, + "logprob": -0.9902344, + "text": "the" + }, + { + "id": 3469, + "logprob": -0.29052734, + "text": "image" + }, + { + "id": 28804, + "logprob": -0.43188477, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.0076828003, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.20092773, + "special": false, + "text": "\n" + }, + { + "id": 16114, + "logprob": -1.2587891, + "special": false, + "text": "Once" + }, + { + "id": 3714, + "logprob": -0.20861816, + "special": false, + "text": " upon" + }, + { + "id": 264, + "logprob": -0.0017719269, + "special": false, + "text": " a" + }, + { + "id": 727, + "logprob": -0.011909485, + "special": false, + "text": " time" + }, + { + "id": 28725, + "logprob": -0.17529297, + "special": false, + "text": "," + }, + { + "id": 736, + "logprob": -0.9082031, + "special": false, + "text": " there" + }, + { + "id": 403, + "logprob": -0.057525635, + "special": false, + "text": " was" + }, + { + "id": 264, + "logprob": -0.009651184, + "special": false, + "text": " a" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nOnce upon a time, there was a" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1247, + "logprob": -2.3886719, + "text": "User" + }, + { + "id": 28747, + "logprob": -12.328125, + "text": ":" + }, + { + "id": 32000, + "logprob": -10.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -17.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -18.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -9.875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": 
-10.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -18.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.71875, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -9.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + 
{ + "id": 32000, + "logprob": -10.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.203125, + "text": "" + }, + { + "id": 32000, + "logprob": 
+  [... long run of prefill-token snapshot entries, flattened in extraction: each added line contributes part of an object of the form { "id": 32000, "logprob": <per-token value, roughly -9.7 to -18.2>, "text": "" } ...]
+ "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 
32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3359375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -13.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 
32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + 
"logprob": -14.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": 
-12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0625, + 
"text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -9.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -18.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.25, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.328125, + 
"text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -17.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + 
{ + "id": 32000, + "logprob": -11.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.265625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-13.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-10.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + 
"text": "" + }, + { + "id": 32000, + "logprob": -13.0, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.25, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5390625, + "text": "" + }, + { + "id": 32000, 
+ "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.25, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -17.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -9.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0546875, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2578125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.25, + "text": "" + }, + { + "id": 32000, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9921875, + "text": "" + }, + { 
+ "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -17.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -19.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.328125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.625, + "text": "" + }, + { + "id": 32000, + "logprob": 
-17.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.21875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -16.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9375, + "text": "" 
+ }, + { + "id": 32000, + "logprob": -11.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -17.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.46875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.5, + "text": "" + }, + { + "id": 32000, + "logprob": -12.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + 
"logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -17.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.8671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + 
"text": "" + }, + { + "id": 32000, + "logprob": -12.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9609375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -18.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.640625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8359375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.25, + "text": "" + }, + { + "id": 32000, + "logprob": -14.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.2421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -16.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.203125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.796875, + "text": "" 
+ }, + { + "id": 32000, + "logprob": -12.4765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6328125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.75, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4453125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.53125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.484375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -10.6953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1640625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.84375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -18.0, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5859375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7578125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.546875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.015625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.15625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6015625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5, + "text": "" + }, + { + "id": 32000, + "logprob": -10.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.1171875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.1328125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.90625, + "text": "" + }, + { + "id": 32000, + 
"logprob": -12.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -10.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4296875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0859375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -19.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 32000, + "logprob": -11.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.859375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8984375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.703125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.4921875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.28125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.96875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0234375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3046875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.078125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.6171875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7265625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0703125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.1484375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.1796875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4140625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -14.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.453125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.2890625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5546875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -14.6171875, 
+ "text": "" + }, + { + "id": 32000, + "logprob": -12.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.3671875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.75, + "text": "" + }, + { + "id": 32000, + "logprob": -11.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -14.5625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9453125, + "text": "" + }, + { + "id": 32000, + "logprob": -10.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.234375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.953125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6484375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5703125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.265625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -15.1953125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.7421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.09375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.0546875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.59375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.3515625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.90625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.609375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.671875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -15.2265625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.78125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.6875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.796875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.03125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.875, + "text": "" + }, + { + "id": 32000, + "logprob": -16.515625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.7734375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.4609375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.171875, + "text": "" + }, + { + "id": 32000, + "logprob": -11.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.4375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.828125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.2734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -14.3984375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.578125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.578125, + "text": "" + }, + { + 
"id": 32000, + "logprob": -11.3359375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.984375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.421875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.34375, + "text": "" + }, + { + "id": 32000, + "logprob": -12.8828125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.890625, + "text": "" + }, + { + "id": 32000, + "logprob": -13.3203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.2109375, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9765625, + "text": "" + }, + { + "id": 32000, + "logprob": -15.140625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.0078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.0390625, + "text": "" + }, + { + "id": 32000, + "logprob": -14.40625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.046875, + "text": "" + }, + { + "id": 32000, + "logprob": -13.8203125, + "text": "" + }, + { + "id": 32000, + "logprob": -13.5078125, + "text": "" + }, + { + "id": 32000, + "logprob": -11.734375, + "text": "" + }, + { + "id": 32000, + "logprob": -13.390625, + "text": "" + }, + { + "id": 32000, + "logprob": -17.3125, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5234375, + "text": "" + }, + { + "id": 32000, + "logprob": -17.625, + "text": "" + }, + { + "id": 32000, + "logprob": -11.9296875, + "text": "" + }, + { + "id": 32000, + "logprob": -12.71875, + "text": "" + }, + { + "id": 32000, + "logprob": -15.9140625, + "text": "" + }, + { + "id": 32000, + "logprob": -16.65625, + "text": "" + }, + { + "id": 32000, + "logprob": -12.5, + "text": "" + }, + { + "id": 2418, + "logprob": -19.0625, + "text": "Can" + }, + { + "id": 368, + "logprob": -0.19726562, + "text": "you" + }, + { + "id": 1912, + "logprob": -1.4990234, + "text": "tell" + }, + { + "id": 528, + "logprob": -0.31152344, + "text": "me" + }, + { + "id": 264, + "logprob": -2.6367188, + "text": "a" + }, + { + "id": 1215, + "logprob": -9.1015625, + "text": "very" + }, + { + "id": 2485, + "logprob": -0.9941406, + "text": "short" + }, + { + "id": 2838, + "logprob": -0.46118164, + "text": "story" + }, + { + "id": 2818, + "logprob": -3.3183594, + "text": "based" + }, + { + "id": 356, + "logprob": -0.029129028, + "text": "on" + }, + { + "id": 272, + "logprob": -0.9902344, + "text": "the" + }, + { + "id": 3469, + "logprob": -0.29052734, + "text": "image" + }, + { + "id": 28804, + "logprob": -0.43188477, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.0076828003, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.19958496, + "special": false, + "text": "\n" + }, + { + "id": 16114, + "logprob": -1.2587891, + "special": false, + "text": "Once" + }, + { + "id": 3714, + "logprob": -0.20861816, + "special": false, + "text": " upon" + }, + { + "id": 264, + "logprob": -0.0017719269, + "special": false, + "text": " a" + }, + { + "id": 727, + "logprob": -0.011749268, + "special": false, + "text": " time" + }, + { + "id": 28725, + "logprob": -0.17529297, + "special": false, + "text": "," + }, + { + "id": 736, + "logprob": -0.9086914, + "special": false, + "text": " there" + }, + { + "id": 403, + "logprob": -0.056732178, + "special": false, + "text": " was" + }, + { + "id": 264, + "logprob": -0.00970459, + "special": false, + "text": " a" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nOnce upon a time, there was a" + } +] diff --git a/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_simple.json b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_simple.json new file mode 100644 index 00000000..f0f2ee9e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_simple.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.00756073, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.20117188, + "special": false, + "text": "\n" + }, + { + "id": 16114, + "logprob": -1.2597656, + "special": false, + "text": "Once" + }, + { + "id": 3714, + "logprob": -0.20825195, + "special": false, + "text": " upon" + }, + { + "id": 264, + "logprob": -0.00178051, + "special": false, + "text": " a" + }, + { + "id": 727, + "logprob": -0.011955261, + "special": false, + "text": " time" + }, + { + "id": 28725, + "logprob": -0.17541504, + "special": false, + "text": "," + }, + { + "id": 736, + "logprob": -0.91308594, + "special": false, + "text": " there" + }, + { + "id": 403, + "logprob": -0.058410645, + "special": false, + "text": " was" + }, + { + "id": 264, + "logprob": -0.009689331, + "special": false, + "text": " a" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nOnce upon a time, there was a" +} diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json new file mode 100644 index 00000000..eaba5078 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.37890625, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.26953125, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1953125, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.53515625, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.6796875, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3125, + "special": false, + "text": " type" + }, + { + "id": 
273, + "logprob": -0.0028533936, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.265625, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" +} diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json new file mode 100644 index 00000000..85e9a9e0 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2502, + "logprob": null, + "text": " red" + }, + { + "id": 13, + "logprob": -2.734375, + "text": "," + }, + { + "id": 8862, + "logprob": -3.6875, + "text": " yellow" + }, + { + "id": 13, + "logprob": -0.40234375, + "text": "," + }, + { + "id": 209, + "logprob": -8.25, + "text": " " + } + ], + "seed": 0, + "tokens": [ + { + "id": 187, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 395, + "logprob": -0.3125, + "special": false, + "text": "and" + }, + { + "id": 4797, + "logprob": 0.0, + "special": false, + "text": " blue" + }, + { + "id": 9830, + "logprob": -1.65625, + "special": false, + "text": " colors" + }, + { + "id": 15, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 329, + "logprob": -2.4375, + "special": false, + "text": " A" + }, + { + "id": 1180, + "logprob": -1.953125, + "special": false, + "text": " number" + }, + { + "id": 273, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 1027, + "logprob": -1.5546875, + "special": false, + "text": " different" + }, + { + "id": 3295, + "logprob": -0.97265625, + "special": false, + "text": " color" + } + ], + "top_tokens": null + }, + "generated_text": "blue, red, yellow, \nand blue colors. A number of different color" +} diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json new file mode 100644 index 00000000..4921c14b --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json @@ -0,0 +1,398 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.83984375, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.84375, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.25, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.37890625, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.4296875, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.078125, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.515625, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.6015625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.65625, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.109375, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.328125, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0032653809, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.28125, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.80078125, + "text": " is" + }, + { + "id": 18147, + "logprob": -13.25, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.828125, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1953125, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.296875, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.3359375, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.2578125, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.5546875, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.62890625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.64453125, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.078125, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.28125, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0030670166, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.3125, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.80078125, + "text": " is" + }, + { + "id": 18147, + "logprob": -13.25, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.828125, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1953125, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.296875, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.3359375, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.2578125, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.5546875, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.62890625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.64453125, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.078125, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.28125, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0030670166, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.3125, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.80078125, + "text": " is" + }, + { + "id": 18147, + "logprob": -13.25, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.828125, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1953125, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.296875, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.3359375, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.2578125, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.5546875, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.62890625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.64453125, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.078125, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.28125, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0030670166, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.3125, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + } +] diff --git a/integration-tests/models/__snapshots__/test_mpt/test_mpt.json b/integration-tests/models/__snapshots__/test_mpt/test_mpt.json new file mode 100644 index 00000000..abbbf03c --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mpt/test_mpt.json @@ -0,0 +1,140 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 17, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -1.5117188, + "text": " is" + }, + { + "id": 18147, + "logprob": -8.96875, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -1.953125, + "text": " Learning" + }, + { + "id": 32, + "logprob": -0.94189453, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 428, + "logprob": -1.5830078, + "special": false, + "text": " -" + }, + { + "id": 18147, + "logprob": -3.3105469, + "special": false, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -0.3215332, + "special": false, + "text": " Learning" + }, + { + "id": 187, + "logprob": -2.5566406, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.6074219, + "special": false, + "text": "Deep" + }, + { + "id": 20727, + "logprob": -0.69628906, + "special": false, + "text": " Learning" + }, + { + "id": 310, + "logprob": -0.6923828, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.5263672, + "special": false, + "text": " a" + }, + { + "id": 749, + "logprob": -1.8544922, + "special": false, + "text": " sub" + }, + { + "id": 3423, + "logprob": -0.6118164, + "special": false, + "text": "field" + }, + { + "id": 273, + "logprob": -0.055877686, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.0537109, + "special": false, + "text": " machine" + }, + { + "id": 4715, + "logprob": -0.0115737915, + "special": false, + "text": " learning" + }, + { + "id": 326, + "logprob": -0.9111328, + "special": false, + "text": " that" + }, + { + "id": 4648, + "logprob": -1.4589844, + "special": false, + "text": " uses" + }, + { + "id": 13345, + "logprob": -1.4853516, + "special": false, + "text": " artificial" + }, + { + "id": 11454, + "logprob": -0.021636963, + "special": false, + "text": " neural" + } + ] + }, + "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" +} diff --git a/integration-tests/models/__snapshots__/test_mpt/test_mpt_load.json b/integration-tests/models/__snapshots__/test_mpt/test_mpt_load.json new file mode 100644 index 00000000..e3bc57ed --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mpt/test_mpt_load.json @@ -0,0 +1,562 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 17, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -1.5117188, + "text": " is" + }, + { + "id": 18147, + "logprob": -8.96875, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -1.953125, + "text": " Learning" + }, + { + "id": 32, + "logprob": -0.94189453, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 428, + "logprob": -1.5830078, + "special": false, + "text": " -" + }, + { + "id": 18147, + "logprob": -3.3183594, + "special": false, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -0.32617188, + "special": false, + "text": " Learning" + }, + { + "id": 187, + "logprob": -2.5742188, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.6015625, + "special": false, + "text": "Deep" + }, + { + "id": 20727, + "logprob": -0.69628906, + "special": false, + "text": " Learning" + }, + { + "id": 310, + "logprob": -0.67822266, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.5395508, + "special": false, + "text": " a" + }, + { + "id": 749, + "logprob": -1.8623047, + "special": false, + "text": " sub" + }, + { + "id": 3423, + "logprob": -0.6020508, + "special": false, + "text": "field" + }, + { + "id": 273, + "logprob": -0.0552063, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.0742188, + "special": false, + "text": " machine" + }, + { + "id": 4715, + "logprob": -0.011405945, + "special": false, + "text": " learning" + }, + { + "id": 326, + "logprob": -0.9165039, + "special": false, + "text": " that" + }, + { + "id": 4648, + "logprob": -1.4501953, + "special": false, + "text": " uses" + }, + { + "id": 13345, + "logprob": -1.4960938, + "special": false, + "text": " artificial" + }, + { + "id": 11454, + "logprob": -0.02116394, + "special": false, + "text": " neural" + } + ] + }, + "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 17, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -1.5, + "text": " is" + }, + { + "id": 18147, + "logprob": -8.984375, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -1.96875, + "text": " Learning" + }, + { + "id": 32, + "logprob": -0.93359375, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 428, + "logprob": -1.5800781, + "special": false, + "text": " -" + }, + { + "id": 18147, + "logprob": -3.3242188, + "special": false, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -0.31835938, + "special": false, + "text": " Learning" + }, + { + "id": 187, + "logprob": -2.5644531, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.5957031, + "special": false, + "text": "Deep" + }, + { + "id": 20727, + "logprob": -0.69628906, + "special": false, + "text": " Learning" + }, + { + "id": 310, + "logprob": -0.68603516, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.5258789, + "special": false, + "text": " a" + }, + { + "id": 749, + "logprob": -1.859375, + "special": false, + "text": " sub" + }, + { + "id": 3423, + "logprob": -0.6166992, + "special": false, + "text": "field" + }, + { + "id": 273, + "logprob": -0.056762695, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.0703125, + "special": false, + "text": " machine" + }, + { + "id": 4715, + "logprob": -0.011428833, + "special": false, + "text": " learning" + }, + { + "id": 326, + "logprob": -0.9213867, + "special": false, + "text": " that" + }, + { + "id": 4648, + "logprob": -1.4726562, + "special": false, + "text": " uses" + }, + { + "id": 13345, + "logprob": -1.5039062, + "special": false, + "text": " artificial" + }, + { + "id": 11454, + "logprob": -0.021652222, + "special": false, + "text": " neural" + } + ] + }, + "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 17, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -1.5, + "text": " is" + }, + { + "id": 18147, + "logprob": -8.984375, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -1.96875, + "text": " Learning" + }, + { + "id": 32, + "logprob": -0.93359375, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 428, + "logprob": -1.5800781, + "special": false, + "text": " -" + }, + { + "id": 18147, + "logprob": -3.3242188, + "special": false, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -0.31835938, + "special": false, + "text": " Learning" + }, + { + "id": 187, + "logprob": -2.5644531, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.5957031, + "special": false, + "text": "Deep" + }, + { + "id": 20727, + "logprob": -0.69628906, + "special": false, + "text": " Learning" + }, + { + "id": 310, + "logprob": -0.68603516, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.5258789, + "special": false, + "text": " a" + }, + { + "id": 749, + "logprob": -1.859375, + "special": false, + "text": " sub" + }, + { + "id": 3423, + "logprob": -0.6166992, + "special": false, + "text": "field" + }, + { + "id": 273, + "logprob": -0.056762695, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.0703125, + "special": false, + "text": " machine" + }, + { + "id": 4715, + "logprob": -0.011428833, + "special": false, + "text": " learning" + }, + { + "id": 326, + "logprob": -0.9213867, + "special": false, + "text": " that" + }, + { + "id": 4648, + "logprob": -1.4726562, + "special": false, + "text": " uses" + }, + { + "id": 13345, + "logprob": -1.5039062, + "special": false, + "text": " artificial" + }, + { + "id": 11454, + "logprob": -0.021652222, + "special": false, + "text": " neural" + } + ] + }, + "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 17, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -1.5, + "text": " is" + }, + { + "id": 18147, + "logprob": -8.984375, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -1.96875, + "text": " Learning" + }, + { + "id": 32, + "logprob": -0.93359375, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 428, + "logprob": -1.5800781, + "special": false, + "text": " -" + }, + { + "id": 18147, + "logprob": -3.3242188, + "special": false, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -0.31835938, + "special": false, + "text": " Learning" + }, + { + "id": 187, + "logprob": -2.5644531, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.5957031, + "special": false, + "text": "Deep" + }, + { + "id": 20727, + "logprob": -0.69628906, + "special": false, + "text": " Learning" + }, + { + "id": 310, + "logprob": -0.68603516, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.5258789, + "special": false, + "text": " a" + }, + { + "id": 749, + "logprob": -1.859375, + "special": false, + "text": " sub" + }, + { + "id": 3423, + "logprob": -0.6166992, + "special": false, + "text": "field" + }, + { + "id": 273, + "logprob": -0.056762695, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.0703125, + "special": false, + "text": " machine" + }, + { + "id": 4715, + "logprob": -0.011428833, + "special": false, + "text": " learning" + }, + { + "id": 326, + "logprob": -0.9213867, + "special": false, + "text": " that" + }, + { + "id": 4648, + "logprob": -1.4726562, + "special": false, + "text": " uses" + }, + { + "id": 13345, + "logprob": -1.5039062, + "special": false, + "text": " artificial" + }, + { + "id": 11454, + "logprob": -0.021652222, + "special": false, + "text": " neural" + } + ] + }, + "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" + } +] diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json new file mode 100644 index 00000000..c1cd24cd --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json @@ -0,0 +1,48 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 5, + "prefill": [ + { + "id": 0, + "logprob": null, + "text": "" + } + ], + "seed": 0, + "tokens": [ + { + "id": 926, + "logprob": -4.3554688, + "special": false, + "text": " To" + }, + { + "id": 18295, + "logprob": -7.7734375, + "special": false, + "text": " sell" + }, + { + "id": 7868, + "logprob": -3.9257812, + "special": false, + "text": " things" + }, + { + "id": 260, + "logprob": -2.4179688, + "special": false, + "text": "." + }, + { + "id": 1, + "logprob": 0.0, + "special": true, + "text": "" + } + ] + }, + "generated_text": "To sell things." 
+} diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json new file mode 100644 index 00000000..5cacf3e9 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json @@ -0,0 +1,79 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 0, + "logprob": null, + "text": "" + } + ], + "seed": 0, + "tokens": [ + { + "id": 16017, + "logprob": 0.0, + "special": false, + "text": " blue" + }, + { + "id": 20495, + "logprob": 0.0, + "special": false, + "text": " sky" + }, + { + "id": 259, + "logprob": -0.4716797, + "special": false, + "text": " " + }, + { + "id": 261, + "logprob": -0.044677734, + "special": false, + "text": "," + }, + { + "id": 35622, + "logprob": -0.79589844, + "special": false, + "text": " cloud" + }, + { + "id": 263, + "logprob": -1.2958984, + "special": false, + "text": "s" + }, + { + "id": 305, + "logprob": 0.0, + "special": false, + "text": " and" + }, + { + "id": 35622, + "logprob": -1.1630859, + "special": false, + "text": " cloud" + }, + { + "id": 263, + "logprob": 0.0, + "special": false, + "text": "s" + }, + { + "id": 1, + "logprob": 0.0, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": "Why is the sky blue?blue sky, clouds and clouds" +} diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json new file mode 100644 index 00000000..c0834ae1 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json @@ -0,0 +1,218 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 6, + "prefill": [ + { + "id": 0, + "logprob": null, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 259, + "logprob": -1.3798828, + "special": false, + "text": " " + }, + { + "id": 39261, + "logprob": -0.36328125, + "special": false, + "text": "Because" + }, + { + "id": 609, + "logprob": -1.0947266, + "special": false, + "text": " it" + }, + { + "id": 339, + "logprob": -0.8286133, + "special": false, + "text": " is" + }, + { + "id": 16017, + "logprob": -1.6826172, + "special": false, + "text": " blue" + }, + { + "id": 1, + "logprob": -0.7290039, + "special": true, + "text": "" + } + ] + }, + "generated_text": "Because it is blue" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 6, + "prefill": [ + { + "id": 0, + "logprob": null, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 259, + "logprob": -1.3789062, + "special": false, + "text": " " + }, + { + "id": 39261, + "logprob": -0.36279297, + "special": false, + "text": "Because" + }, + { + "id": 609, + "logprob": -1.0966797, + "special": false, + "text": " it" + }, + { + "id": 339, + "logprob": -0.8276367, + "special": false, + "text": " is" + }, + { + "id": 16017, + "logprob": -1.6845703, + "special": false, + "text": " blue" + }, + { + "id": 1, + "logprob": -0.72753906, + "special": true, + "text": "" + } + ] + }, + "generated_text": "Because it is blue" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 6, + "prefill": [ + { + "id": 0, + "logprob": null, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 259, + "logprob": 
-1.3789062, + "special": false, + "text": " " + }, + { + "id": 39261, + "logprob": -0.36279297, + "special": false, + "text": "Because" + }, + { + "id": 609, + "logprob": -1.0966797, + "special": false, + "text": " it" + }, + { + "id": 339, + "logprob": -0.8276367, + "special": false, + "text": " is" + }, + { + "id": 16017, + "logprob": -1.6845703, + "special": false, + "text": " blue" + }, + { + "id": 1, + "logprob": -0.72753906, + "special": true, + "text": "" + } + ] + }, + "generated_text": "Because it is blue" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 6, + "prefill": [ + { + "id": 0, + "logprob": null, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 259, + "logprob": -1.3789062, + "special": false, + "text": " " + }, + { + "id": 39261, + "logprob": -0.36279297, + "special": false, + "text": "Because" + }, + { + "id": 609, + "logprob": -1.0966797, + "special": false, + "text": " it" + }, + { + "id": 339, + "logprob": -0.8276367, + "special": false, + "text": " is" + }, + { + "id": 16017, + "logprob": -1.6845703, + "special": false, + "text": " blue" + }, + { + "id": 1, + "logprob": -0.72753906, + "special": true, + "text": "" + } + ] + }, + "generated_text": "Because it is blue" + } +] diff --git a/integration-tests/models/__snapshots__/test_neox/test_neox.json b/integration-tests/models/__snapshots__/test_neox/test_neox.json new file mode 100644 index 00000000..2abc27e1 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox/test_neox.json @@ -0,0 +1,113 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1992188, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8984375, + "text": " mood" + }, + { + "id": 3063, + "logprob": -4.0976562, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14562988, + "text": "?" 
+ }, + { + "id": 50279, + "logprob": -0.26733398, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.86279297, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.94921875, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1835938, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.074035645, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.86376953, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.2070312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4365234, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.109375, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -0.93408203, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.8808594, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" +} diff --git a/integration-tests/models/__snapshots__/test_neox/test_neox_load.json b/integration-tests/models/__snapshots__/test_neox/test_neox_load.json new file mode 100644 index 00000000..f37f0d8e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox/test_neox_load.json @@ -0,0 +1,454 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" 
+ }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" 
+ }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + } +] diff --git a/integration-tests/models/__snapshots__/test_neox_sharded/test_neox.json b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox.json new file mode 100644 index 00000000..25cdf6d7 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox.json @@ -0,0 +1,163 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.4179688, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1542969, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.359375, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.006038666, + "text": "e" + }, + { + "id": 13, + "logprob": -7.328125, + "text": "," + }, + { + "id": 285, + "logprob": -0.3173828, + "text": " and" + }, + { + "id": 752, + "logprob": -2.0625, + "text": " what" + }, + { + "id": 434, + "logprob": -5.7734375, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.74072266, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.5898438, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.2949219, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.40625, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1113281, + "text": " word" + }, + { + "id": 32, + "logprob": -0.008056641, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.3300781, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.28125, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.5878906, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.5449219, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.05038452, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.002292633, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.3828278e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.0010242462, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.090270996, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12719727, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.016571045, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.43432617, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" +} diff --git a/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json new file mode 100644 index 00000000..0b38e701 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json @@ -0,0 +1,654 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.4179688, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1542969, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.359375, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.006038666, + "text": "e" + }, + { + "id": 13, + "logprob": -7.328125, + "text": "," + }, + { + "id": 285, + "logprob": -0.3173828, + "text": " and" + }, + { + "id": 752, + "logprob": -2.0625, + "text": " what" + }, + { + "id": 434, + "logprob": -5.7734375, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.74072266, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.5898438, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.2949219, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.40625, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1113281, + "text": " word" + }, + { + "id": 32, + "logprob": -0.008056641, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.3300781, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.28125, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.5878906, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.5498047, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.04815674, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.002313614, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.2636185e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.0010147095, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.0859375, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12609863, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.016601562, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.38256836, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1640625, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.40625, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005420685, + "text": "e" + }, + { + "id": 13, + "logprob": -7.2226562, + "text": "," + }, + { + "id": 285, + "logprob": -0.26879883, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1992188, + "text": " what" + }, + { + "id": 434, + "logprob": -5.46875, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.8017578, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6796875, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.1972656, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.4453125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1933594, + "text": " word" + }, + { + "id": 32, + "logprob": -0.007858276, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.328125, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.21875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.6201172, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.546875, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.051879883, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0020179749, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -9.059906e-06, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00096797943, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.07940674, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12182617, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.017227173, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.44482422, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1640625, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.40625, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005420685, + "text": "e" + }, + { + "id": 13, + "logprob": -7.2226562, + "text": "," + }, + { + "id": 285, + "logprob": -0.26879883, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1992188, + "text": " what" + }, + { + "id": 434, + "logprob": -5.46875, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.8017578, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6796875, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.1972656, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.4453125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1933594, + "text": " word" + }, + { + "id": 32, + "logprob": -0.007858276, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.328125, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.21875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.6201172, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.546875, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.051879883, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0020179749, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -9.059906e-06, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00096797943, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.07940674, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12182617, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.017227173, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.44482422, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1640625, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.40625, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005420685, + "text": "e" + }, + { + "id": 13, + "logprob": -7.2226562, + "text": "," + }, + { + "id": 285, + "logprob": -0.26879883, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1992188, + "text": " what" + }, + { + "id": 434, + "logprob": -5.46875, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.8017578, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6796875, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.1972656, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.4453125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1933594, + "text": " word" + }, + { + "id": 32, + "logprob": -0.007858276, + "text": "?" 
+ }, + { + "id": 0, + "logprob": -2.328125, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.21875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.6201172, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.546875, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.051879883, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0020179749, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.04904175e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.0009560585, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.08557129, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12084961, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.01737976, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.4025879, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + } +] diff --git a/integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded.json b/integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded.json new file mode 100644 index 00000000..6090e2c9 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded.json @@ -0,0 +1,60 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 7, + "prefill": [ + { + "id": 0, + "logprob": null, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 3, + "logprob": -0.7001953, + "special": false, + "text": " " + }, + { + "id": 18, + "logprob": -1.1943359, + "special": false, + "text": "-" + }, + { + "id": 26937, + "logprob": -1.2099609, + "special": false, + "text": "196" + }, + { + "id": 3, + "logprob": -1.2451172, + "special": false, + "text": " " + }, + { + "id": 1956, + "logprob": -0.3322754, + "special": false, + "text": "°" + }, + { + "id": 254, + "logprob": -0.19213867, + "special": false, + "text": "C" + }, + { + "id": 1, + "logprob": -0.030151367, + "special": true, + "text": "" + } + ] + }, + "generated_text": "-196 °C" +} diff --git a/integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded_load.json b/integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded_load.json new file mode 100644 index 00000000..3e9af12e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded_load.json @@ -0,0 +1,242 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 7, + "prefill": [ + { + "id": 0, + "logprob": null, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 3, + "logprob": -0.7001953, + "special": false, + "text": " " + }, + { + "id": 18, + "logprob": -1.1943359, + "special": false, + "text": "-" + }, + { + "id": 26937, + "logprob": -1.2119141, + "special": false, + "text": "196" + }, + { + "id": 3, + "logprob": -1.2480469, + "special": false, + "text": " " + }, + { + "id": 1956, + "logprob": -0.33203125, + "special": false, + "text": "°" + }, + { + "id": 254, + "logprob": -0.19250488, + "special": false, + "text": "C" + }, + { + "id": 1, + "logprob": -0.030166626, + "special": true, + "text": "" + } + ] + }, + "generated_text": "-196 °C" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 7, + "prefill": [ + { + "id": 0, + "logprob": 
null, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 3, + "logprob": -0.7001953, + "special": false, + "text": " " + }, + { + "id": 18, + "logprob": -1.1943359, + "special": false, + "text": "-" + }, + { + "id": 26937, + "logprob": -1.2119141, + "special": false, + "text": "196" + }, + { + "id": 3, + "logprob": -1.2480469, + "special": false, + "text": " " + }, + { + "id": 1956, + "logprob": -0.33203125, + "special": false, + "text": "°" + }, + { + "id": 254, + "logprob": -0.19250488, + "special": false, + "text": "C" + }, + { + "id": 1, + "logprob": -0.030166626, + "special": true, + "text": "" + } + ] + }, + "generated_text": "-196 °C" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 7, + "prefill": [ + { + "id": 0, + "logprob": null, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 3, + "logprob": -0.7001953, + "special": false, + "text": " " + }, + { + "id": 18, + "logprob": -1.1943359, + "special": false, + "text": "-" + }, + { + "id": 26937, + "logprob": -1.2119141, + "special": false, + "text": "196" + }, + { + "id": 3, + "logprob": -1.2480469, + "special": false, + "text": " " + }, + { + "id": 1956, + "logprob": -0.33203125, + "special": false, + "text": "°" + }, + { + "id": 254, + "logprob": -0.19250488, + "special": false, + "text": "C" + }, + { + "id": 1, + "logprob": -0.030166626, + "special": true, + "text": "" + } + ] + }, + "generated_text": "-196 °C" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 7, + "prefill": [ + { + "id": 0, + "logprob": null, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 3, + "logprob": -0.7001953, + "special": false, + "text": " " + }, + { + "id": 18, + "logprob": -1.1943359, + "special": false, + "text": "-" + }, + { + "id": 26937, + "logprob": -1.2099609, + "special": false, + "text": "196" + }, + { + "id": 3, + "logprob": -1.2451172, + "special": false, + "text": " " + }, + { + "id": 1956, + "logprob": -0.3322754, + "special": false, + "text": "°" + }, + { + "id": 254, + "logprob": -0.19213867, + "special": false, + "text": "C" + }, + { + "id": 1, + "logprob": -0.030151367, + "special": true, + "text": "" + } + ] + }, + "generated_text": "-196 °C" + } +] diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json new file mode 100644 index 00000000..a4c34a10 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json @@ -0,0 +1,39 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": [ + { + "function": { + "arguments": { + "format": "celsius", + "location": "Brooklyn" + }, + "description": null, + "name": "get_current_weather" + }, + "id": 0, + "type": "function" + } + ] + }, + "usage": null + } + ], + "created": 1712782670, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native", + "usage": { + "completion_tokens": 37, + "prompt_tokens": 524, + "total_tokens": 561 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json new file mode 100644 
index 00000000..04bcdc4e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json @@ -0,0 +1,39 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": [ + { + "function": { + "arguments": { + "format": "celsius", + "location": "Brooklyn" + }, + "description": null, + "name": "get_current_weather" + }, + "id": 0, + "type": "function" + } + ] + }, + "usage": null + } + ], + "created": 1712787937, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native", + "usage": { + "completion_tokens": 37, + "prompt_tokens": 524, + "total_tokens": 561 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json new file mode 100644 index 00000000..603c90af --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json @@ -0,0 +1,39 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": [ + { + "function": { + "arguments": { + "format": "celsius", + "location": "New York, NY" + }, + "description": null, + "name": "get_current_weather" + }, + "id": 0, + "type": "function" + } + ] + }, + "usage": null + } + ], + "created": 1712852394, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native", + "usage": { + "completion_tokens": 48, + "prompt_tokens": 320, + "total_tokens": 368 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json new file mode 100644 index 00000000..0cd3c67f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json @@ -0,0 +1,38 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": [ + { + "function": { + "arguments": { + "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options." 
+ }, + "description": null, + "name": "notify_error" + }, + "id": 0, + "type": "function" + } + ] + }, + "usage": null + } + ], + "created": 1712852597, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.5-native", + "usage": { + "completion_tokens": 39, + "prompt_tokens": 496, + "total_tokens": 535 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json new file mode 100644 index 00000000..f72a5d38 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json @@ -0,0 +1,27 @@ +{ + "choices": [ + { + "delta": { + "content": null, + "role": "assistant", + "tool_calls": { + "function": { + "arguments": "", + "name": null + }, + "id": "", + "index": 0, + "type": "function" + } + }, + "finish_reason": "eos_token", + "index": 0, + "logprobs": null + } + ], + "created": 1712788218, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.1-native" +} diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py new file mode 100644 index 00000000..bdcbdc78 --- /dev/null +++ b/integration-tests/models/test_bloom_560m.py @@ -0,0 +1,64 @@ +import pytest + + +@pytest.fixture(scope="module") +def bloom_560_handle(launcher): + with launcher("bigscience/bloom-560m") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def bloom_560(bloom_560_handle): + await bloom_560_handle.health(240) + return bloom_560_handle.client + + +@pytest.mark.asyncio +async def test_bloom_560m(bloom_560, response_snapshot): + response = await bloom_560.generate( + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + top_p=0.9, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_bloom_560m_all_params(bloom_560, response_snapshot): + response = await bloom_560.generate( + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot): + responses = await generate_load( + bloom_560, + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_bloom_560m_sharded.py b/integration-tests/models/test_bloom_560m_sharded.py new file mode 100644 index 00000000..3995f9e5 --- /dev/null +++ b/integration-tests/models/test_bloom_560m_sharded.py @@ -0,0 +1,44 @@ +import pytest + + +@pytest.fixture(scope="module") +def bloom_560m_sharded_handle(launcher): + with launcher("bigscience/bloom-560m", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def bloom_560m_sharded(bloom_560m_sharded_handle): + await 
 bloom_560m_sharded_handle.health(240) + return bloom_560m_sharded_handle.client + + +@pytest.mark.asyncio +async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot): + response = await bloom_560m_sharded.generate( + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + top_p=0.9, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_bloom_560m_sharded_load( + bloom_560m_sharded, generate_load, response_snapshot +): + responses = await generate_load( + bloom_560m_sharded, + "Pour déguster un ortolan, il faut tout d'abord", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_chat_llama.py b/integration-tests/models/test_chat_llama.py new file mode 100644 index 00000000..10df6dbd --- /dev/null +++ b/integration-tests/models/test_chat_llama.py @@ -0,0 +1,43 @@ +import pytest +import json + +from text_generation.types import GrammarType + + +@pytest.fixture(scope="module") +def flash_llama_chat_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_chat(flash_llama_chat_handle): + await flash_llama_chat_handle.health(300) + return flash_llama_chat_handle.client + + +@pytest.mark.private +async def test_flash_llama_simple(flash_llama_chat, response_snapshot): + response = await flash_llama_chat.chat( + max_tokens=100, + seed=1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + + print(repr(response.choices[0].message.content)) + assert ( + response.choices[0].message.content + == "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas" + ) + assert response == response_snapshot diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py new file mode 100644 index 00000000..cafa8ea6 --- /dev/null +++ b/integration-tests/models/test_completion_prompts.py @@ -0,0 +1,109 @@ +import pytest +import requests +import json +from aiohttp import ClientSession + +from text_generation.types import ( + Completion, +) + + +@pytest.fixture(scope="module") +def flash_llama_completion_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_completion(flash_llama_completion_handle): + await flash_llama_completion_handle.health(300) + return flash_llama_completion_handle.client + + +# NOTE: since `v1/completions` is a deprecated interface/endpoint we do not provide a convenience +# method for it. Instead, we use the `requests` library to make the HTTP request directly.
+ + +def test_flash_llama_completion_single_prompt( + flash_llama_completion, response_snapshot +): + response = requests.post( + f"{flash_llama_completion.base_url}/v1/completions", + json={ + "model": "tgi", + "prompt": "Say this is a test", + "max_tokens": 5, + "seed": 0, + }, + headers=flash_llama_completion.headers, + stream=False, + ) + response = response.json() + assert len(response["choices"]) == 1 + + assert response == response_snapshot + + +def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot): + response = requests.post( + f"{flash_llama_completion.base_url}/v1/completions", + json={ + "model": "tgi", + "prompt": ["Say", "this", "is", "a"], + "max_tokens": 10, + "seed": 0, + }, + headers=flash_llama_completion.headers, + stream=False, + ) + response = response.json() + assert len(response["choices"]) == 4 + + all_indexes = [choice["index"] for choice in response["choices"]] + all_indexes.sort() + assert all_indexes == [0, 1, 2, 3] + + assert response == response_snapshot + + +async def test_flash_llama_completion_many_prompts_stream( + flash_llama_completion, response_snapshot +): + request = { + "model": "tgi", + "prompt": [ + "What color is the sky?", + "Is water wet?", + "What is the capital of France?", + "def mai", + ], + "max_tokens": 10, + "seed": 0, + "stream": True, + } + + url = f"{flash_llama_completion.base_url}/v1/completions" + + chunks = [] + async with ClientSession(headers=flash_llama_completion.headers) as session: + async with session.post(url, json=request) as response: + # iterate over the stream + async for chunk in response.content.iter_any(): + # remove "data:" + chunk = chunk.decode().split("\n\n") + # remove "data:" if present + chunk = [c.replace("data:", "") for c in chunk] + # remove empty strings + chunk = [c for c in chunk if c] + # parse json + chunk = [json.loads(c) for c in chunk] + + for c in chunk: + chunks.append(Completion(**c)) + assert "choices" in c + assert 0 <= c["choices"][0]["index"] <= 4 + + assert response.status == 200 + assert chunks == response_snapshot diff --git a/integration-tests/models/test_flash_awq.py b/integration-tests/models/test_flash_awq.py new file mode 100644 index 00000000..ead918c3 --- /dev/null +++ b/integration-tests/models/test_flash_awq.py @@ -0,0 +1,70 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_awq_handle(launcher): + with launcher( + "abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", + num_shard=1, + quantize="awq", + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_awq(flash_llama_awq_handle): + await flash_llama_awq_handle.health(300) + return flash_llama_awq_handle.client + + +@pytest.mark.asyncio +async def test_flash_llama_awq(flash_llama_awq, response_snapshot): + response = await flash_llama_awq.generate( + "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert ( + response.generated_text + == "\nWhat is the difference between Deep Learning and Machine" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): + response = await flash_llama_awq.generate( + "What is Deep Learning?", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert 
response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot): + responses = await generate_load( + flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all( + [ + r.generated_text + == "\nWhat is the difference between Deep Learning and Machine" + for r in responses + ] + ) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_awq_sharded.py b/integration-tests/models/test_flash_awq_sharded.py new file mode 100644 index 00000000..a83614ac --- /dev/null +++ b/integration-tests/models/test_flash_awq_sharded.py @@ -0,0 +1,51 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_awq_handle_sharded(launcher): + with launcher( + "abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", + num_shard=2, + quantize="awq", + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded): + await flash_llama_awq_handle_sharded.health(300) + return flash_llama_awq_handle_sharded.client + + +@pytest.mark.asyncio +async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot): + response = await flash_llama_awq_sharded.generate( + "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert ( + response.generated_text + == "\nWhat is the difference between Deep Learning and Machine" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_llama_awq_load_sharded( + flash_llama_awq_sharded, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_awq_sharded, "What is Deep Learning?", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all( + [ + r.generated_text + == "\nWhat is the difference between Deep Learning and Machine" + for r in responses + ] + ) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_falcon.py b/integration-tests/models/test_flash_falcon.py new file mode 100644 index 00000000..eac91984 --- /dev/null +++ b/integration-tests/models/test_flash_falcon.py @@ -0,0 +1,65 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_falcon_handle(launcher): + with launcher("tiiuae/falcon-7b", trust_remote_code=True) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_falcon(flash_falcon_handle): + await flash_falcon_handle.health(300) + return flash_falcon_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_falcon(flash_falcon, response_snapshot): + response = await flash_falcon.generate( + "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_falcon_all_params(flash_falcon, response_snapshot): + response = await flash_falcon.generate( + "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. 
Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot): + responses = await generate_load( + flash_falcon, + "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_gemma.py b/integration-tests/models/test_flash_gemma.py new file mode 100644 index 00000000..7ab43111 --- /dev/null +++ b/integration-tests/models/test_flash_gemma.py @@ -0,0 +1,58 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_gemma_handle(launcher): + with launcher("google/gemma-2b", num_shard=1) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_gemma(flash_gemma_handle): + await flash_gemma_handle.health(300) + return flash_gemma_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma(flash_gemma, response_snapshot): + response = await flash_gemma.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_all_params(flash_gemma, response_snapshot): + response = await flash_gemma.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot): + responses = await generate_load(flash_gemma, "Test request", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_gemma_gptq.py b/integration-tests/models/test_flash_gemma_gptq.py new file mode 100644 index 00000000..8ac5f5a1 --- /dev/null +++ b/integration-tests/models/test_flash_gemma_gptq.py @@ -0,0 +1,64 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_gemma_gptq_handle(launcher): + with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_gemma_gptq(flash_gemma_gptq_handle): + await flash_gemma_gptq_handle.health(300) + return flash_gemma_gptq_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_gptq(flash_gemma_gptq, 
ignore_logprob_response_snapshot): + response = await flash_gemma_gptq.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == ignore_logprob_response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_gptq_all_params( + flash_gemma_gptq, ignore_logprob_response_snapshot +): + response = await flash_gemma_gptq.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == ignore_logprob_response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_gptq_load( + flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot +): + responses = await generate_load( + flash_gemma_gptq, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == ignore_logprob_response_snapshot diff --git a/integration-tests/models/test_flash_gpt2.py b/integration-tests/models/test_flash_gpt2.py new file mode 100644 index 00000000..0c7977d0 --- /dev/null +++ b/integration-tests/models/test_flash_gpt2.py @@ -0,0 +1,44 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_gpt2_handle(launcher): + with launcher("openai-community/gpt2", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_gpt2(flash_gpt2_handle): + await flash_gpt2_handle.health(300) + return flash_gpt2_handle.client + + +@pytest.mark.asyncio +async def test_flash_gpt2(flash_gpt2, response_snapshot): + response = await flash_gpt2.generate( + "What is deep learning?", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot): + responses = await generate_load( + flash_gpt2, + "What is deep learning?", + max_new_tokens=10, + n=4, + ) + + generated_texts = [r.generated_text for r in responses] + + assert len(generated_texts) == 4 + assert all( + [text == generated_texts[0] for text in generated_texts] + ), generated_texts + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_grammar_llama.py b/integration-tests/models/test_flash_grammar_llama.py new file mode 100644 index 00000000..ce1cf787 --- /dev/null +++ b/integration-tests/models/test_flash_grammar_llama.py @@ -0,0 +1,150 @@ +import pytest +import json + +from text_generation.types import GrammarType + + +@pytest.fixture(scope="module") +def flash_llama_grammar_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_grammar(flash_llama_grammar_handle): + await flash_llama_grammar_handle.health(300) + return flash_llama_grammar_handle.client + + +@pytest.mark.asyncio +async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot): + response = await flash_llama_grammar.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == 
response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot): + response = await flash_llama_grammar.generate( + "Whats Googles DNS", + max_new_tokens=10, + decoder_input_details=True, + seed=0, + grammar={ + "type": GrammarType.Regex, # "regex" + "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", + }, + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == "42.1.1.101" + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot): + response = await flash_llama_grammar.generate( + "info: david holtz like trees and has two cats. ", + max_new_tokens=100, + decoder_input_details=True, + seed=0, + grammar={ + "type": GrammarType.Json, # "json" + "value": json.dumps( + { + "type": "object", + "$id": "https://example.com/person.schema.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "Person", + "properties": { + "firstName": { + "type": "string", + "description": "The person'''s first name.", + }, + "lastName": { + "type": "string", + "description": "The person'''s last name.", + }, + "hobby": { + "description": "The person'''s hobby.", + "type": "string", + }, + "numCats": { + "description": "The number of cats the person has.", + "type": "integer", + "minimum": 0, + }, + }, + "required": ["firstName", "lastName", "hobby", "numCats"], + } + ), + }, + ) + + assert response.details.generated_tokens == 30 + assert ( + response.generated_text + == '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}' + ) + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_flash_llama_grammar_load( + flash_llama_grammar, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_grammar, + "name: david. email: ", + max_new_tokens=10, + n=4, + stop_sequences=[".com"], + seed=0, + grammar={ + "type": GrammarType.Regex, # "regex" + "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex + }, + ) + + assert len(responses) == 4 + + expected = "123456@gmail.com" + + for response in responses: + assert response.generated_text == expected + + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot + + +# this is the same as the above test, but only fires off a single request +# this is only to ensure that the parallel and single inference produce the same result +@pytest.mark.skip +@pytest.mark.asyncio +async def test_flash_llama_grammar_single_load_instance( + flash_llama_grammar, generate_load, response_snapshot +): + response = await flash_llama_grammar.generate( + "name: david. 
email: ", + max_new_tokens=10, + stop_sequences=[".com"], + seed=0, + grammar={ + "type": GrammarType.Regex, # "regex" + "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex + }, + ) + + # assert response.details.generated_tokens == 30 + assert response.generated_text == "123456@gmail.com" + + assert response == response_snapshot diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py new file mode 100644 index 00000000..c69314ff --- /dev/null +++ b/integration-tests/models/test_flash_llama.py @@ -0,0 +1,58 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_handle(launcher): + with launcher("huggingface/llama-7b", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama(flash_llama_handle): + await flash_llama_handle.health(300) + return flash_llama_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama(flash_llama, response_snapshot): + response = await flash_llama.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_all_params(flash_llama, response_snapshot): + response = await flash_llama.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 5 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_load(flash_llama, generate_load, response_snapshot): + responses = await generate_load(flash_llama, "Test request", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_llama_exl2.py b/integration-tests/models/test_flash_llama_exl2.py new file mode 100644 index 00000000..18319f60 --- /dev/null +++ b/integration-tests/models/test_flash_llama_exl2.py @@ -0,0 +1,73 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_exl2_handle(launcher): + with launcher( + "turboderp/Llama-3-8B-Instruct-exl2", + revision="2.5bpw", + # Set max input length to avoid OOM due to extremely large + # scratch buffer. 
+ max_input_length=1024, + num_shard=1, + quantize="exl2", + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_exl2(flash_llama_exl2_handle): + await flash_llama_exl2_handle.health(300) + return flash_llama_exl2_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot): + response = await flash_llama_exl2.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == ignore_logprob_response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_exl2_all_params( + flash_llama_exl2, ignore_logprob_response_snapshot +): + response = await flash_llama_exl2.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert ( + response.generated_text == 'Test request. The server responds with a "200 OK"' + ) + assert response == ignore_logprob_response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_exl2_load( + flash_llama_exl2, generate_load, ignore_logprob_response_snapshot +): + responses = await generate_load( + flash_llama_exl2, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == ignore_logprob_response_snapshot diff --git a/integration-tests/models/test_flash_llama_gptq.py b/integration-tests/models/test_flash_llama_gptq.py new file mode 100644 index 00000000..b87f054b --- /dev/null +++ b/integration-tests/models/test_flash_llama_gptq.py @@ -0,0 +1,61 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_gptq_handle(launcher): + with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_gptq(flash_llama_gptq_handle): + await flash_llama_gptq_handle.health(300) + return flash_llama_gptq_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): + response = await flash_llama_gptq.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): + response = await flash_llama_gptq.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_gptq_load( + flash_llama_gptq, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_gptq, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_medusa.py 
b/integration-tests/models/test_flash_medusa.py new file mode 100644 index 00000000..27db5665 --- /dev/null +++ b/integration-tests/models/test_flash_medusa.py @@ -0,0 +1,64 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_medusa_handle(launcher): + with launcher( + "FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2, revision="refs/pr/1" + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_medusa(flash_medusa_handle): + await flash_medusa_handle.health(300) + return flash_medusa_handle.client + + +@pytest.mark.asyncio +async def test_flash_medusa_simple(flash_medusa, response_snapshot): + response = await flash_medusa.generate( + "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_medusa_all_params(flash_medusa, response_snapshot): + response = await flash_medusa.generate( + "What is Deep Learning?", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot): + responses = await generate_load( + flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all( + [r.generated_text == responses[0].generated_text for r in responses] + ), f"{[r.generated_text for r in responses]}" + assert ( + responses[0].generated_text == "\nDeep learning is a subset of machine learning" + ) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_mistral.py b/integration-tests/models/test_flash_mistral.py new file mode 100644 index 00000000..52b51928 --- /dev/null +++ b/integration-tests/models/test_flash_mistral.py @@ -0,0 +1,61 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_mistral_handle(launcher): + with launcher("mistralai/Mistral-7B-Instruct-v0.1") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_mistral(flash_mistral_handle): + await flash_mistral_handle.health(300) + return flash_mistral_handle.client + + +@pytest.mark.asyncio +async def test_flash_mistral(flash_mistral, response_snapshot): + response = await flash_mistral.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == ": Let n = 10 - 1" + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_mistral_all_params(flash_mistral, response_snapshot): + response = await flash_mistral.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot): + responses = await generate_load( + flash_mistral, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all( + [r.generated_text == 
responses[0].generated_text for r in responses] + ), f"{[r.generated_text for r in responses]}" + assert responses[0].generated_text == ": Let n = 10 - 1" + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py new file mode 100644 index 00000000..0289c61d --- /dev/null +++ b/integration-tests/models/test_flash_neox.py @@ -0,0 +1,46 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_neox_handle(launcher): + with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_neox(flash_neox_handle): + await flash_neox_handle.health(300) + return flash_neox_handle.client + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_flash_neox(flash_neox, response_snapshot): + response = await flash_neox.generate( + "<|USER|>What's your mood today?<|ASSISTANT|>", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_flash_neox_load(flash_neox, generate_load, response_snapshot): + responses = await generate_load( + flash_neox, + "<|USER|>What's your mood today?<|ASSISTANT|>", + max_new_tokens=10, + n=4, + ) + + generated_texts = [r.generated_text for r in responses] + + assert len(generated_texts) == 4 + assert all( + [text == generated_texts[0] for text in generated_texts] + ), generated_texts + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_neox_sharded.py b/integration-tests/models/test_flash_neox_sharded.py new file mode 100644 index 00000000..8a491915 --- /dev/null +++ b/integration-tests/models/test_flash_neox_sharded.py @@ -0,0 +1,40 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_neox_sharded_handle(launcher): + with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_neox_sharded(flash_neox_sharded_handle): + await flash_neox_sharded_handle.health(300) + return flash_neox_sharded_handle.client + + +@pytest.mark.asyncio +async def test_flash_neox(flash_neox_sharded, response_snapshot): + response = await flash_neox_sharded.generate( + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot): + responses = await generate_load( + flash_neox_sharded, + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py new file mode 100644 index 00000000..d4e83c9f --- /dev/null +++ b/integration-tests/models/test_flash_pali_gemma.py @@ -0,0 +1,39 @@ +import pytest +import requests +import io +import base64 + + +@pytest.fixture(scope="module") +def flash_pali_gemma_handle(launcher): + with launcher( + "google/paligemma-3b-pt-224", + num_shard=1, + revision="float16", + max_input_length=4000, + 
max_total_tokens=4096, + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_pali_gemma(flash_pali_gemma_handle): + await flash_pali_gemma_handle.health(300) + return flash_pali_gemma_handle.client + + +def get_cow_beach(): + with open("integration-tests/images/cow_beach.png", "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): + cow = get_cow_beach() + inputs = f"![]({cow})Where is the cow standing?\n" + response = await flash_pali_gemma.generate(inputs, max_new_tokens=20) + + assert response.generated_text == "beach" + assert response == response_snapshot diff --git a/integration-tests/models/test_flash_phi.py b/integration-tests/models/test_flash_phi.py new file mode 100644 index 00000000..9d6ca566 --- /dev/null +++ b/integration-tests/models/test_flash_phi.py @@ -0,0 +1,60 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_phi_handle(launcher): + with launcher("microsoft/phi-2", num_shard=1) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_phi(flash_phi_handle): + await flash_phi_handle.health(300) + return flash_phi_handle.client + + +@pytest.mark.asyncio +async def test_flash_phi(flash_phi, response_snapshot): + response = await flash_phi.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == ': {request}")\n response = self' + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_phi_all_params(flash_phi, response_snapshot): + response = await flash_phi.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["network"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 6 + assert response.generated_text == "Test request to send data over a network" + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_phi_load(flash_phi, generate_load, response_snapshot): + responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all( + [r.generated_text == responses[0].generated_text for r in responses] + ), f"{[r.generated_text for r in responses]}" + assert responses[0].generated_text == ': {request}")\n response = self' + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_qwen2.py b/integration-tests/models/test_flash_qwen2.py new file mode 100644 index 00000000..2963aeb4 --- /dev/null +++ b/integration-tests/models/test_flash_qwen2.py @@ -0,0 +1,59 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_qwen2_handle(launcher): + with launcher("Qwen/Qwen1.5-0.5B") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_qwen2(flash_qwen2_handle): + await flash_qwen2_handle.health(300) + return flash_qwen2_handle.client + + +@pytest.mark.asyncio +async def test_flash_qwen2(flash_qwen2, response_snapshot): + response = await flash_qwen2.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == "\n# Create a 
request\nrequest = requests.get" + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot): + response = await flash_qwen2.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot): + responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all( + [r.generated_text == responses[0].generated_text for r in responses] + ), f"{[r.generated_text for r in responses]}" + assert responses[0].generated_text == "\n# Create a request\nrequest = requests.get" + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_santacoder.py b/integration-tests/models/test_flash_santacoder.py new file mode 100644 index 00000000..0f005f15 --- /dev/null +++ b/integration-tests/models/test_flash_santacoder.py @@ -0,0 +1,37 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_santacoder_handle(launcher): + with launcher("bigcode/santacoder") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_santacoder(flash_santacoder_handle): + await flash_santacoder_handle.health(300) + return flash_santacoder_handle.client + + +@pytest.mark.asyncio +async def test_flash_santacoder(flash_santacoder, response_snapshot): + response = await flash_santacoder.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_santacoder_load( + flash_santacoder, generate_load, response_snapshot +): + responses = await generate_load( + flash_santacoder, "def print_hello", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py new file mode 100644 index 00000000..64e8b27c --- /dev/null +++ b/integration-tests/models/test_flash_starcoder.py @@ -0,0 +1,53 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_starcoder_handle(launcher): + with launcher("bigcode/starcoder", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_starcoder(flash_starcoder_handle): + await flash_starcoder_handle.health(300) + return flash_starcoder_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder(flash_starcoder, response_snapshot): + response = await flash_starcoder.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot): + response = await flash_starcoder.generate( + "def print_hello", + max_new_tokens=60, + temperature=0.2, + top_p=0.95, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 
60 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder_load(flash_starcoder, generate_load, response_snapshot): + responses = await generate_load( + flash_starcoder, "def print_hello", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_starcoder2.py b/integration-tests/models/test_flash_starcoder2.py new file mode 100644 index 00000000..ea665b6c --- /dev/null +++ b/integration-tests/models/test_flash_starcoder2.py @@ -0,0 +1,55 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_starcoder2_handle(launcher): + with launcher("bigcode/starcoder2-3b", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_starcoder2(flash_starcoder2_handle): + await flash_starcoder2_handle.health(300) + return flash_starcoder2_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder2(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "def print_hello", + max_new_tokens=60, + temperature=0.2, + top_p=0.95, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 60 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder2_load( + flash_starcoder2, generate_load, response_snapshot +): + responses = await generate_load( + flash_starcoder2, "def print_hello", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py new file mode 100644 index 00000000..329158b7 --- /dev/null +++ b/integration-tests/models/test_flash_starcoder_gptq.py @@ -0,0 +1,57 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_starcoder_gptq_handle(launcher): + with launcher("Narsil/starcoder-gptq", num_shard=2, quantize="gptq") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_starcoder_gptq(flash_starcoder_gptq_handle): + await flash_starcoder_gptq_handle.health(300) + return flash_starcoder_gptq_handle.client + + +@pytest.mark.asyncio +async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot): + response = await flash_starcoder_gptq.generate( + "def geometric_mean(L: List[float]):", + max_new_tokens=20, + decoder_input_details=True, + ) + assert response.details.generated_tokens == 20 + assert response == generous_response_snapshot + + +@pytest.mark.asyncio +async def test_flash_starcoder_gptq_default_params( + flash_starcoder_gptq, generous_response_snapshot +): + response = await flash_starcoder_gptq.generate( + "def geometric_mean(L: List[float]):", + max_new_tokens=20, + temperature=0.2, + top_p=0.95, + decoder_input_details=True, + seed=0, + ) + assert response.details.generated_tokens == 20 + assert response == 
generous_response_snapshot + + +@pytest.mark.asyncio +async def test_flash_starcoder_gptq_load( + flash_starcoder_gptq, generate_load, generous_response_snapshot +): + responses = await generate_load( + flash_starcoder_gptq, + "def geometric_mean(L: List[float]):", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == generous_response_snapshot diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py new file mode 100644 index 00000000..ce5da8a9 --- /dev/null +++ b/integration-tests/models/test_grammar_llama.py @@ -0,0 +1,70 @@ +import pytest +import json + +from text_generation.types import GrammarType + + +@pytest.fixture(scope="module") +def non_flash_llama_grammar_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + num_shard=1, + disable_grammar_support=False, + use_flash_attention=False, + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def non_flash_llama_grammar(non_flash_llama_grammar_handle): + await non_flash_llama_grammar_handle.health(300) + return non_flash_llama_grammar_handle.client + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot): + response = await non_flash_llama_grammar.generate( + "info: david holtz like trees and has two cats. ", + max_new_tokens=100, + decoder_input_details=True, + seed=0, + grammar={ + "type": GrammarType.Json, + "value": json.dumps( + { + "type": "object", + "$id": "https://example.com/person.schema.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "Person", + "properties": { + "firstName": { + "type": "string", + "description": "The person'''s first name.", + }, + "lastName": { + "type": "string", + "description": "The person'''s last name.", + }, + "hobby": { + "description": "The person'''s hobby.", + "type": "string", + }, + "numCats": { + "description": "The number of cats the person has.", + "type": "integer", + "minimum": 0, + }, + }, + "required": ["firstName", "lastName", "hobby", "numCats"], + } + ), + }, + ) + + assert response.details.generated_tokens == 30 + assert ( + response.generated_text + == '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}' + ) + assert response == response_snapshot diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py new file mode 100644 index 00000000..aeeaffa1 --- /dev/null +++ b/integration-tests/models/test_idefics.py @@ -0,0 +1,62 @@ +import pytest +import base64 + + +@pytest.fixture(scope="module") +def idefics_handle(launcher): + with launcher( + "HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16" + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def idefics(idefics_handle): + await idefics_handle.health(300) + return idefics_handle.client + + +# TODO fix the server parsser to count inline image tokens correctly +def get_chicken(): + with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + +@pytest.mark.asyncio +async def test_idefics(idefics, response_snapshot): + chicken = get_chicken() + response = await idefics.generate( + f"User:![]({chicken})Can you tell me a very short story based on the image?", + max_new_tokens=10, + 
decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert ( + response.generated_text == " \nAssistant: A rooster stands" + ), f"{repr(response.generated_text)}" + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_idefics_load(idefics, generate_load, response_snapshot): + chicken = get_chicken() + responses = await generate_load( + idefics, + f"User:![]({chicken})Can you tell me a very short story based on the image?", + max_new_tokens=10, + n=4, + ) + + generated_texts = [r.generated_text for r in responses] + + assert ( + generated_texts[0] == " \nAssistant: A rooster stands" + ), f"{generated_texts[0]}" + assert len(generated_texts) == 4 + assert all( + [text == generated_texts[0] for text in generated_texts] + ), generated_texts + + assert responses == response_snapshot diff --git a/integration-tests/models/test_idefics2.py b/integration-tests/models/test_idefics2.py new file mode 100644 index 00000000..d34cce34 --- /dev/null +++ b/integration-tests/models/test_idefics2.py @@ -0,0 +1,81 @@ +import pytest +import base64 + + +# TODO fix the server parsser to count inline image tokens correctly +def get_chicken(): + with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + +@pytest.fixture(scope="module") +def flash_idefics2_next_handle(launcher): + with launcher( + "HuggingFaceM4/idefics2-8b", + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_idefics2_next(flash_idefics2_next_handle): + await flash_idefics2_next_handle.health(300) + return flash_idefics2_next_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot): + chicken = get_chicken() + response = await flash_idefics2_next.generate( + f"User:![]({chicken})Write me a short story \nAssistant:", + max_new_tokens=10, + ) + assert ( + response.generated_text == " A chicken is sitting on a pile of money." + ), f"{repr(response.generated_text)}" + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot): + response = await flash_idefics2_next.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_idefics2_next_load( + flash_idefics2_next, generate_load, response_snapshot +): + chicken = get_chicken() + responses = await generate_load( + flash_idefics2_next, + f"User:![]({chicken})Write me a short story \nAssistant:", + max_new_tokens=10, + n=4, + ) + generated_texts = [r.generated_text for r in responses] + assert generated_texts[0] == " A chicken is sitting on a pile of money."
+ assert len(generated_texts) == 4 + assert all([r.generated_text == generated_texts[0] for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_llava_next.py b/integration-tests/models/test_llava_next.py new file mode 100644 index 00000000..f5b290b1 --- /dev/null +++ b/integration-tests/models/test_llava_next.py @@ -0,0 +1,84 @@ +import pytest +import base64 + + +# TODO fix the server parsser to count inline image tokens correctly +def get_chicken(): + with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + +@pytest.fixture(scope="module") +def flash_llava_next_handle(launcher): + with launcher( + "llava-hf/llava-v1.6-mistral-7b-hf", + num_shard=4, + max_input_length=4000, + max_total_tokens=4096, + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llava_next(flash_llava_next_handle): + await flash_llava_next_handle.health(300) + return flash_llava_next_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llava_next_simple(flash_llava_next, response_snapshot): + chicken = get_chicken() + response = await flash_llava_next.generate( + f"User:![]({chicken})Can you tell me a very short story based on the image?", + max_new_tokens=10, + ) + assert ( + response.generated_text == "\n\nOnce upon a time, there was a" + ), f"{repr(response.generated_text)}" + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot): + response = await flash_llava_next.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 6 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llava_next_load( + flash_llava_next, generate_load, response_snapshot +): + chicken = get_chicken() + responses = await generate_load( + flash_llava_next, + f"User:![]({chicken})Can you tell me a very short story based on the image?", + max_new_tokens=10, + n=4, + ) + generated_texts = [r.generated_text for r in responses] + assert generated_texts[0] == "\n\nOnce upon a time, there was a" + assert len(generated_texts) == 4 + assert all([r.generated_text == generated_texts[0] for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py new file mode 100644 index 00000000..bf3701b4 --- /dev/null +++ b/integration-tests/models/test_mamba.py @@ -0,0 +1,65 @@ +import pytest + + +@pytest.fixture(scope="module") +def fused_kernel_mamba_handle(launcher): + with launcher("state-spaces/mamba-130m", num_shard=1) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def fused_kernel_mamba(fused_kernel_mamba_handle): + await fused_kernel_mamba_handle.health(300) + return fused_kernel_mamba_handle.client + + +@pytest.mark.asyncio +async def test_mamba(fused_kernel_mamba, response_snapshot): + response = await fused_kernel_mamba.generate( + "What is Deep Learning?", max_new_tokens=10 + ) + + assert 
response.details.generated_tokens == 10 + assert response.generated_text == "\n\nDeep learning is a new type of machine" + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): + response = await fused_kernel_mamba.generate( + "blue, red, yellow, ", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert ( + response.generated_text + == "blue, red, yellow, \nand blue colors. A number of different color" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_mamba_load( + fused_kernel_mamba, generate_load, generous_response_snapshot +): + responses = await generate_load( + fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + assert responses[0].generated_text == "\n\nDeep learning is a new type of machine" + + assert responses == generous_response_snapshot diff --git a/integration-tests/models/test_mpt.py b/integration-tests/models/test_mpt.py new file mode 100644 index 00000000..d58a8c5a --- /dev/null +++ b/integration-tests/models/test_mpt.py @@ -0,0 +1,48 @@ +import pytest + + +@pytest.fixture(scope="module") +def mpt_sharded_handle(launcher): + with launcher("mosaicml/mpt-7b", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def mpt_sharded(mpt_sharded_handle): + await mpt_sharded_handle.health(300) + return mpt_sharded_handle.client + + +@pytest.mark.asyncio +async def test_mpt(mpt_sharded, response_snapshot): + response = await mpt_sharded.generate( + "What is Deep Learning?", + max_new_tokens=17, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 17 + assert ( + response.generated_text + == " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_mpt_load(mpt_sharded, generate_load, response_snapshot): + responses = await generate_load( + mpt_sharded, + "What is Deep Learning?", + max_new_tokens=17, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + assert ( + responses[0].generated_text + == " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural" + ) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_mt0_base.py b/integration-tests/models/test_mt0_base.py new file mode 100644 index 00000000..c877056a --- /dev/null +++ b/integration-tests/models/test_mt0_base.py @@ -0,0 +1,64 @@ +import pytest + + +@pytest.fixture(scope="module") +def mt0_base_handle(launcher): + with launcher("bigscience/mt0-base") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def mt0_base(mt0_base_handle): + await mt0_base_handle.health(300) + return mt0_base_handle.client + + +@pytest.mark.asyncio +async def test_mt0_base(mt0_base, response_snapshot): + response = await mt0_base.generate( + "Why is the sky blue?", + max_new_tokens=10, + top_p=0.9, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 5 + assert response == response_snapshot + 
+ +@pytest.mark.asyncio +async def test_mt0_base_all_params(mt0_base, response_snapshot): + response = await mt0_base.generate( + "Why is the sky blue?", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_mt0_base_load(mt0_base, generate_load, response_snapshot): + responses = await generate_load( + mt0_base, + "Why is the sky blue?", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_neox.py b/integration-tests/models/test_neox.py new file mode 100644 index 00000000..7b88f86a --- /dev/null +++ b/integration-tests/models/test_neox.py @@ -0,0 +1,48 @@ +import pytest + + +@pytest.fixture(scope="module") +def neox_handle(launcher): + with launcher( + "stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def neox(neox_handle): + await neox_handle.health(300) + return neox_handle.client + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_neox(neox, response_snapshot): + response = await neox.generate( + "<|USER|>What's your mood today?<|ASSISTANT|>", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_neox_load(neox, generate_load, response_snapshot): + responses = await generate_load( + neox, + "<|USER|>What's your mood today?<|ASSISTANT|>", + max_new_tokens=10, + n=4, + ) + + generated_texts = [r.generated_text for r in responses] + + assert len(generated_texts) == 4 + assert all( + [text == generated_texts[0] for text in generated_texts] + ), generated_texts + + assert responses == response_snapshot diff --git a/integration-tests/models/test_neox_sharded.py b/integration-tests/models/test_neox_sharded.py new file mode 100644 index 00000000..8cee8765 --- /dev/null +++ b/integration-tests/models/test_neox_sharded.py @@ -0,0 +1,44 @@ +import pytest + + +@pytest.fixture(scope="module") +def neox_sharded_handle(launcher): + with launcher( + "OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def neox_sharded(neox_sharded_handle): + await neox_sharded_handle.health(300) + return neox_sharded_handle.client + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_neox(neox_sharded, response_snapshot): + response = await neox_sharded.generate( + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_neox_load(neox_sharded, generate_load, response_snapshot): + responses = await generate_load( + neox_sharded, + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text ==
responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_t5_sharded.py b/integration-tests/models/test_t5_sharded.py new file mode 100644 index 00000000..4b4cfd98 --- /dev/null +++ b/integration-tests/models/test_t5_sharded.py @@ -0,0 +1,39 @@ +import pytest + + +@pytest.fixture(scope="module") +def t5_sharded_handle(launcher): + with launcher("google/flan-t5-xxl", num_shard=4) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def t5_sharded(t5_sharded_handle): + await t5_sharded_handle.health(300) + return t5_sharded_handle.client + + +@pytest.mark.asyncio +async def test_t5_sharded(t5_sharded, response_snapshot): + response = await t5_sharded.generate( + "Please answer the following question. What is the boiling point of Nitrogen?", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot): + responses = await generate_load( + t5_sharded, + "Please answer the following question. What is the boiling point of Nitrogen?", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py new file mode 100644 index 00000000..0af3f66a --- /dev/null +++ b/integration-tests/models/test_tools_llama.py @@ -0,0 +1,259 @@ +import pytest +import json + +from text_generation.types import GrammarType + + +@pytest.fixture(scope="module") +def flash_llama_grammar_tools_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle): + await flash_llama_grammar_tools_handle.health(300) + return flash_llama_grammar_tools_handle.client + + +# tools to be used in the following tests +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_n_day_weather_forecast", + "description": "Get an N-day weather forecast", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. 
Infer this from the users location.", + }, + "num_days": { + "type": "integer", + "description": "The number of days to forecast", + }, + }, + "required": ["location", "format", "num_days"], + }, + }, + }, +] + + +@pytest.mark.skip(reason="Takes too long to run") +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + assert response.choices[0].message.content == None + assert response.choices[0].message.tool_calls == [ + { + "id": 0, + "type": "function", + "function": { + "description": None, + "name": "get_current_weather", + "arguments": {"format": "celsius", "location": "New York, NY"}, + }, + } + ] + assert response == response_snapshot + + +@pytest.mark.skip(reason="Takes too long to run") +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_auto( + flash_llama_grammar_tools, response_snapshot +): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + tool_choice="auto", + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + assert response.choices[0].message.content == None + assert response.choices[0].message.tool_calls == [ + { + "id": 0, + "type": "function", + "function": { + "description": None, + "name": "get_current_weather", + "arguments": {"format": "celsius", "location": "New York, NY"}, + }, + } + ] + + assert response == response_snapshot + + +@pytest.mark.skip(reason="Takes too long to run") +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_choice( + flash_llama_grammar_tools, response_snapshot +): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + tool_choice="get_current_weather", + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + assert response.choices[0].message.content == None + assert response.choices[0].message.tool_calls == [ + { + "id": 0, + "type": "function", + "function": { + "description": None, + "name": "get_current_weather", + "arguments": {"format": "celsius", "location": "New York, NY"}, + }, + } + ] + + assert response == response_snapshot + + +@pytest.mark.skip(reason="Takes too long to run") +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_stream( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + tool_choice="get_current_weather", + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! 
Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Paris, France?", + }, + ], + stream=True, + ) + + count = 0 + async for response in responses: + count += 1 + + assert count == 38 + assert response == response_snapshot + + +@pytest.mark.skip(reason="Takes too long to run") +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_insufficient_information( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=8, + tools=tools, + tool_choice="auto", + messages=[ + { + "role": "system", + "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", + }, + { + "role": "user", + "content": "Tell me a story about 3 sea creatures", + }, + ], + stream=False, + ) + + assert responses.choices[0].message.content == None + assert responses.choices[0].message.tool_calls == [ + { + "function": { + "arguments": { + "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options." + }, + "description": None, + "name": "notify_error", + }, + "id": 0, + "type": "function", + } + ] + + assert responses == response_snapshot diff --git a/integration-tests/poetry.lock b/integration-tests/poetry.lock new file mode 100644 index 00000000..3af99942 --- /dev/null +++ b/integration-tests/poetry.lock @@ -0,0 +1,1052 @@ +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. + +[[package]] +name = "aiohttp" +version = "3.8.5" +description = "Async http client/server framework (asyncio)" +optional = false +python-versions = ">=3.6" +files = [ + {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a94159871304770da4dd371f4291b20cac04e8c94f11bdea1c3478e557fbe0d8"}, + {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:13bf85afc99ce6f9ee3567b04501f18f9f8dbbb2ea11ed1a2e079670403a7c84"}, + {file = "aiohttp-3.8.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ce2ac5708501afc4847221a521f7e4b245abf5178cf5ddae9d5b3856ddb2f3a"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96943e5dcc37a6529d18766597c491798b7eb7a61d48878611298afc1fca946c"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ad5c3c4590bb3cc28b4382f031f3783f25ec223557124c68754a2231d989e2b"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c413c633d0512df4dc7fd2373ec06cc6a815b7b6d6c2f208ada7e9e93a5061d"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df72ac063b97837a80d80dec8d54c241af059cc9bb42c4de68bd5b61ceb37caa"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c48c5c0271149cfe467c0ff8eb941279fd6e3f65c9a388c984e0e6cf57538e14"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:368a42363c4d70ab52c2c6420a57f190ed3dfaca6a1b19afda8165ee16416a82"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7607ec3ce4993464368505888af5beb446845a014bc676d349efec0e05085905"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0d21c684808288a98914e5aaf2a7c6a3179d4df11d249799c32d1808e79503b5"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = 
"sha256:312fcfbacc7880a8da0ae8b6abc6cc7d752e9caa0051a53d217a650b25e9a691"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ad093e823df03bb3fd37e7dec9d4670c34f9e24aeace76808fc20a507cace825"}, + {file = "aiohttp-3.8.5-cp310-cp310-win32.whl", hash = "sha256:33279701c04351a2914e1100b62b2a7fdb9a25995c4a104259f9a5ead7ed4802"}, + {file = "aiohttp-3.8.5-cp310-cp310-win_amd64.whl", hash = "sha256:6e4a280e4b975a2e7745573e3fc9c9ba0d1194a3738ce1cbaa80626cc9b4f4df"}, + {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ae871a964e1987a943d83d6709d20ec6103ca1eaf52f7e0d36ee1b5bebb8b9b9"}, + {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:461908b2578955045efde733719d62f2b649c404189a09a632d245b445c9c975"}, + {file = "aiohttp-3.8.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:72a860c215e26192379f57cae5ab12b168b75db8271f111019509a1196dfc780"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc14be025665dba6202b6a71cfcdb53210cc498e50068bc088076624471f8bb9"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8af740fc2711ad85f1a5c034a435782fbd5b5f8314c9a3ef071424a8158d7f6b"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:841cd8233cbd2111a0ef0a522ce016357c5e3aff8a8ce92bcfa14cef890d698f"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed1c46fb119f1b59304b5ec89f834f07124cd23ae5b74288e364477641060ff"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84f8ae3e09a34f35c18fa57f015cc394bd1389bce02503fb30c394d04ee6b938"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:62360cb771707cb70a6fd114b9871d20d7dd2163a0feafe43fd115cfe4fe845e"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:23fb25a9f0a1ca1f24c0a371523546366bb642397c94ab45ad3aedf2941cec6a"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0ba0d15164eae3d878260d4c4df859bbdc6466e9e6689c344a13334f988bb53"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5d20003b635fc6ae3f96d7260281dfaf1894fc3aa24d1888a9b2628e97c241e5"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0175d745d9e85c40dcc51c8f88c74bfbaef9e7afeeeb9d03c37977270303064c"}, + {file = "aiohttp-3.8.5-cp311-cp311-win32.whl", hash = "sha256:2e1b1e51b0774408f091d268648e3d57f7260c1682e7d3a63cb00d22d71bb945"}, + {file = "aiohttp-3.8.5-cp311-cp311-win_amd64.whl", hash = "sha256:043d2299f6dfdc92f0ac5e995dfc56668e1587cea7f9aa9d8a78a1b6554e5755"}, + {file = "aiohttp-3.8.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cae533195e8122584ec87531d6df000ad07737eaa3c81209e85c928854d2195c"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f21e83f355643c345177a5d1d8079f9f28b5133bcd154193b799d380331d5d3"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7a75ef35f2df54ad55dbf4b73fe1da96f370e51b10c91f08b19603c64004acc"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e2e9839e14dd5308ee773c97115f1e0a1cb1d75cbeeee9f33824fa5144c7634"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:c44e65da1de4403d0576473e2344828ef9c4c6244d65cf4b75549bb46d40b8dd"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78d847e4cde6ecc19125ccbc9bfac4a7ab37c234dd88fbb3c5c524e8e14da543"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:c7a815258e5895d8900aec4454f38dca9aed71085f227537208057853f9d13f2"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:8b929b9bd7cd7c3939f8bcfffa92fae7480bd1aa425279d51a89327d600c704d"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:5db3a5b833764280ed7618393832e0853e40f3d3e9aa128ac0ba0f8278d08649"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:a0215ce6041d501f3155dc219712bc41252d0ab76474615b9700d63d4d9292af"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:fd1ed388ea7fbed22c4968dd64bab0198de60750a25fe8c0c9d4bef5abe13824"}, + {file = "aiohttp-3.8.5-cp36-cp36m-win32.whl", hash = "sha256:6e6783bcc45f397fdebc118d772103d751b54cddf5b60fbcc958382d7dd64f3e"}, + {file = "aiohttp-3.8.5-cp36-cp36m-win_amd64.whl", hash = "sha256:b5411d82cddd212644cf9360879eb5080f0d5f7d809d03262c50dad02f01421a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:01d4c0c874aa4ddfb8098e85d10b5e875a70adc63db91f1ae65a4b04d3344cda"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5980a746d547a6ba173fd5ee85ce9077e72d118758db05d229044b469d9029a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a482e6da906d5e6e653be079b29bc173a48e381600161c9932d89dfae5942ef"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80bd372b8d0715c66c974cf57fe363621a02f359f1ec81cba97366948c7fc873"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1161b345c0a444ebcf46bf0a740ba5dcf50612fd3d0528883fdc0eff578006a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd56db019015b6acfaaf92e1ac40eb8434847d9bf88b4be4efe5bfd260aee692"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:153c2549f6c004d2754cc60603d4668899c9895b8a89397444a9c4efa282aaf4"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4a01951fabc4ce26ab791da5f3f24dca6d9a6f24121746eb19756416ff2d881b"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bfb9162dcf01f615462b995a516ba03e769de0789de1cadc0f916265c257e5d8"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7dde0009408969a43b04c16cbbe252c4f5ef4574ac226bc8815cd7342d2028b6"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4149d34c32f9638f38f544b3977a4c24052042affa895352d3636fa8bffd030a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-win32.whl", hash = "sha256:68c5a82c8779bdfc6367c967a4a1b2aa52cd3595388bf5961a62158ee8a59e22"}, + {file = "aiohttp-3.8.5-cp37-cp37m-win_amd64.whl", hash = "sha256:2cf57fb50be5f52bda004b8893e63b48530ed9f0d6c96c84620dc92fe3cd9b9d"}, + {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:eca4bf3734c541dc4f374ad6010a68ff6c6748f00451707f39857f429ca36ced"}, + {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = 
"sha256:1274477e4c71ce8cfe6c1ec2f806d57c015ebf84d83373676036e256bc55d690"}, + {file = "aiohttp-3.8.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:28c543e54710d6158fc6f439296c7865b29e0b616629767e685a7185fab4a6b9"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:910bec0c49637d213f5d9877105d26e0c4a4de2f8b1b29405ff37e9fc0ad52b8"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5443910d662db951b2e58eb70b0fbe6b6e2ae613477129a5805d0b66c54b6cb7"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e460be6978fc24e3df83193dc0cc4de46c9909ed92dd47d349a452ef49325b7"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb1558def481d84f03b45888473fc5a1f35747b5f334ef4e7a571bc0dfcb11f8"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:34dd0c107799dcbbf7d48b53be761a013c0adf5571bf50c4ecad5643fe9cfcd0"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aa1990247f02a54185dc0dff92a6904521172a22664c863a03ff64c42f9b5410"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0e584a10f204a617d71d359fe383406305a4b595b333721fa50b867b4a0a1548"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:a3cf433f127efa43fee6b90ea4c6edf6c4a17109d1d037d1a52abec84d8f2e42"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c11f5b099adafb18e65c2c997d57108b5bbeaa9eeee64a84302c0978b1ec948b"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:84de26ddf621d7ac4c975dbea4c945860e08cccde492269db4e1538a6a6f3c35"}, + {file = "aiohttp-3.8.5-cp38-cp38-win32.whl", hash = "sha256:ab88bafedc57dd0aab55fa728ea10c1911f7e4d8b43e1d838a1739f33712921c"}, + {file = "aiohttp-3.8.5-cp38-cp38-win_amd64.whl", hash = "sha256:5798a9aad1879f626589f3df0f8b79b3608a92e9beab10e5fda02c8a2c60db2e"}, + {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a6ce61195c6a19c785df04e71a4537e29eaa2c50fe745b732aa937c0c77169f3"}, + {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:773dd01706d4db536335fcfae6ea2440a70ceb03dd3e7378f3e815b03c97ab51"}, + {file = "aiohttp-3.8.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f83a552443a526ea38d064588613aca983d0ee0038801bc93c0c916428310c28"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f7372f7341fcc16f57b2caded43e81ddd18df53320b6f9f042acad41f8e049a"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea353162f249c8097ea63c2169dd1aa55de1e8fecbe63412a9bc50816e87b761"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5d47ae48db0b2dcf70bc8a3bc72b3de86e2a590fc299fdbbb15af320d2659de"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d827176898a2b0b09694fbd1088c7a31836d1a505c243811c87ae53a3f6273c1"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3562b06567c06439d8b447037bb655ef69786c590b1de86c7ab81efe1c9c15d8"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4e874cbf8caf8959d2adf572a78bba17cb0e9d7e51bb83d86a3697b686a0ab4d"}, + {file = 
"aiohttp-3.8.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6809a00deaf3810e38c628e9a33271892f815b853605a936e2e9e5129762356c"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:33776e945d89b29251b33a7e7d006ce86447b2cfd66db5e5ded4e5cd0340585c"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eaeed7abfb5d64c539e2db173f63631455f1196c37d9d8d873fc316470dfbacd"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e91d635961bec2d8f19dfeb41a539eb94bd073f075ca6dae6c8dc0ee89ad6f91"}, + {file = "aiohttp-3.8.5-cp39-cp39-win32.whl", hash = "sha256:00ad4b6f185ec67f3e6562e8a1d2b69660be43070bd0ef6fcec5211154c7df67"}, + {file = "aiohttp-3.8.5-cp39-cp39-win_amd64.whl", hash = "sha256:c0a9034379a37ae42dea7ac1e048352d96286626251862e448933c0f59cbd79c"}, + {file = "aiohttp-3.8.5.tar.gz", hash = "sha256:b9552ec52cc147dbf1944ac7ac98af7602e51ea2dcd076ed194ca3c0d1c7d0bc"}, +] + +[package.dependencies] +aiosignal = ">=1.1.2" +async-timeout = ">=4.0.0a3,<5.0" +attrs = ">=17.3.0" +charset-normalizer = ">=2.0,<4.0" +frozenlist = ">=1.1.1" +multidict = ">=4.5,<7.0" +yarl = ">=1.0,<2.0" + +[package.extras] +speedups = ["Brotli", "aiodns", "cchardet"] + +[[package]] +name = "aiosignal" +version = "1.3.1" +description = "aiosignal: a list of registered asynchronous callbacks" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, + {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, +] + +[package.dependencies] +frozenlist = ">=1.1.0" + +[[package]] +name = "annotated-types" +version = "0.6.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.6.0-py3-none-any.whl", hash = "sha256:0641064de18ba7a25dee8f96403ebc39113d0cb953a01429249d5c7564666a43"}, + {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, +] + +[[package]] +name = "async-timeout" +version = "4.0.3" +description = "Timeout context manager for asyncio programs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, +] + +[[package]] +name = "attrs" +version = "23.1.0" +description = "Classes Without Boilerplate" +optional = false +python-versions = ">=3.7" +files = [ + {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, + {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"}, +] + +[package.extras] +cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] +dev = ["attrs[docs,tests]", "pre-commit"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] +tests = ["attrs[tests-no-zope]", "zope-interface"] +tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] + +[[package]] +name = "certifi" +version = "2023.7.22" +description = "Python package for providing Mozilla's CA Bundle." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, + {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"}, +] + +[[package]] +name = "charset-normalizer" +version = "3.2.0" +description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "charset-normalizer-3.2.0.tar.gz", hash = "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-win32.whl", hash = "sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", 
hash = "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-win32.whl", hash = "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7"}, + {file = 
"charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-win32.whl", hash = "sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-win32.whl", hash = "sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f"}, + {file = 
"charset_normalizer-3.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-win32.whl", hash = "sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80"}, + {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"}, +] + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "colored" +version = "1.4.4" +description = "Simple library for color and formatting to terminal" +optional = false +python-versions = "*" +files = [ + {file = "colored-1.4.4.tar.gz", hash = "sha256:04ff4d4dd514274fe3b99a21bb52fb96f2688c01e93fba7bef37221e7cb56ce0"}, +] + +[[package]] +name = "docker" +version = "6.1.3" +description = "A Python library for the Docker Engine API." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "docker-6.1.3-py3-none-any.whl", hash = "sha256:aecd2277b8bf8e506e484f6ab7aec39abe0038e29fa4a6d3ba86c3fe01844ed9"}, + {file = "docker-6.1.3.tar.gz", hash = "sha256:aa6d17830045ba5ef0168d5eaa34d37beeb113948c413affe1d5991fc11f9a20"}, +] + +[package.dependencies] +packaging = ">=14.0" +pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""} +requests = ">=2.26.0" +urllib3 = ">=1.26.0" +websocket-client = ">=0.32.0" + +[package.extras] +ssh = ["paramiko (>=2.4.3)"] + +[[package]] +name = "exceptiongroup" +version = "1.1.3" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"}, + {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "filelock" +version = "3.12.3" +description = "A platform independent file lock." +optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.12.3-py3-none-any.whl", hash = "sha256:f067e40ccc40f2b48395a80fcbd4728262fab54e232e090a4063ab804179efeb"}, + {file = "filelock-3.12.3.tar.gz", hash = "sha256:0ecc1dd2ec4672a10c8550a8182f1bd0c0a5088470ecd5a125e45f49472fac3d"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.7.1", markers = "python_version < \"3.11\""} + +[package.extras] +docs = ["furo (>=2023.7.26)", "sphinx (>=7.1.2)", "sphinx-autodoc-typehints (>=1.24)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3)", "diff-cover (>=7.7)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-timeout (>=2.1)"] + +[[package]] +name = "frozenlist" +version = "1.4.0" +description = "A list-like structure which implements collections.abc.MutableSequence" +optional = false +python-versions = ">=3.8" +files = [ + {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:764226ceef3125e53ea2cb275000e309c0aa5464d43bd72abd661e27fffc26ab"}, + {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d6484756b12f40003c6128bfcc3fa9f0d49a687e171186c2d85ec82e3758c559"}, + {file = "frozenlist-1.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9ac08e601308e41eb533f232dbf6b7e4cea762f9f84f6357136eed926c15d12c"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d081f13b095d74b67d550de04df1c756831f3b83dc9881c38985834387487f1b"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:71932b597f9895f011f47f17d6428252fc728ba2ae6024e13c3398a087c2cdea"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:981b9ab5a0a3178ff413bca62526bb784249421c24ad7381e39d67981be2c326"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e41f3de4df3e80de75845d3e743b3f1c4c8613c3997a912dbf0229fc61a8b963"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6918d49b1f90821e93069682c06ffde41829c346c66b721e65a5c62b4bab0300"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0e5c8764c7829343d919cc2dfc587a8db01c4f70a4ebbc49abde5d4b158b007b"}, + 
{file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8d0edd6b1c7fb94922bf569c9b092ee187a83f03fb1a63076e7774b60f9481a8"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e29cda763f752553fa14c68fb2195150bfab22b352572cb36c43c47bedba70eb"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:0c7c1b47859ee2cac3846fde1c1dc0f15da6cec5a0e5c72d101e0f83dcb67ff9"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:901289d524fdd571be1c7be054f48b1f88ce8dddcbdf1ec698b27d4b8b9e5d62"}, + {file = "frozenlist-1.4.0-cp310-cp310-win32.whl", hash = "sha256:1a0848b52815006ea6596c395f87449f693dc419061cc21e970f139d466dc0a0"}, + {file = "frozenlist-1.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:b206646d176a007466358aa21d85cd8600a415c67c9bd15403336c331a10d956"}, + {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de343e75f40e972bae1ef6090267f8260c1446a1695e77096db6cfa25e759a95"}, + {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad2a9eb6d9839ae241701d0918f54c51365a51407fd80f6b8289e2dfca977cc3"}, + {file = "frozenlist-1.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bd7bd3b3830247580de99c99ea2a01416dfc3c34471ca1298bccabf86d0ff4dc"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bdf1847068c362f16b353163391210269e4f0569a3c166bc6a9f74ccbfc7e839"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38461d02d66de17455072c9ba981d35f1d2a73024bee7790ac2f9e361ef1cd0c"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5a32087d720c608f42caed0ef36d2b3ea61a9d09ee59a5142d6070da9041b8f"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dd65632acaf0d47608190a71bfe46b209719bf2beb59507db08ccdbe712f969b"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261b9f5d17cac914531331ff1b1d452125bf5daa05faf73b71d935485b0c510b"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b89ac9768b82205936771f8d2eb3ce88503b1556324c9f903e7156669f521472"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:008eb8b31b3ea6896da16c38c1b136cb9fec9e249e77f6211d479db79a4eaf01"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e74b0506fa5aa5598ac6a975a12aa8928cbb58e1f5ac8360792ef15de1aa848f"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:490132667476f6781b4c9458298b0c1cddf237488abd228b0b3650e5ecba7467"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:76d4711f6f6d08551a7e9ef28c722f4a50dd0fc204c56b4bcd95c6cc05ce6fbb"}, + {file = "frozenlist-1.4.0-cp311-cp311-win32.whl", hash = "sha256:a02eb8ab2b8f200179b5f62b59757685ae9987996ae549ccf30f983f40602431"}, + {file = "frozenlist-1.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:515e1abc578dd3b275d6a5114030b1330ba044ffba03f94091842852f806f1c1"}, + {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f0ed05f5079c708fe74bf9027e95125334b6978bf07fd5ab923e9e55e5fbb9d3"}, + {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = 
"sha256:ca265542ca427bf97aed183c1676e2a9c66942e822b14dc6e5f42e038f92a503"}, + {file = "frozenlist-1.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:491e014f5c43656da08958808588cc6c016847b4360e327a62cb308c791bd2d9"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17ae5cd0f333f94f2e03aaf140bb762c64783935cc764ff9c82dff626089bebf"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e78fb68cf9c1a6aa4a9a12e960a5c9dfbdb89b3695197aa7064705662515de2"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5655a942f5f5d2c9ed93d72148226d75369b4f6952680211972a33e59b1dfdc"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c11b0746f5d946fecf750428a95f3e9ebe792c1ee3b1e96eeba145dc631a9672"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e66d2a64d44d50d2543405fb183a21f76b3b5fd16f130f5c99187c3fb4e64919"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:88f7bc0fcca81f985f78dd0fa68d2c75abf8272b1f5c323ea4a01a4d7a614efc"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5833593c25ac59ede40ed4de6d67eb42928cca97f26feea219f21d0ed0959b79"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:fec520865f42e5c7f050c2a79038897b1c7d1595e907a9e08e3353293ffc948e"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:b826d97e4276750beca7c8f0f1a4938892697a6bcd8ec8217b3312dad6982781"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ceb6ec0a10c65540421e20ebd29083c50e6d1143278746a4ef6bcf6153171eb8"}, + {file = "frozenlist-1.4.0-cp38-cp38-win32.whl", hash = "sha256:2b8bcf994563466db019fab287ff390fffbfdb4f905fc77bc1c1d604b1c689cc"}, + {file = "frozenlist-1.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:a6c8097e01886188e5be3e6b14e94ab365f384736aa1fca6a0b9e35bd4a30bc7"}, + {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6c38721585f285203e4b4132a352eb3daa19121a035f3182e08e437cface44bf"}, + {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a0c6da9aee33ff0b1a451e867da0c1f47408112b3391dd43133838339e410963"}, + {file = "frozenlist-1.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93ea75c050c5bb3d98016b4ba2497851eadf0ac154d88a67d7a6816206f6fa7f"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f61e2dc5ad442c52b4887f1fdc112f97caeff4d9e6ebe78879364ac59f1663e1"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa384489fefeb62321b238e64c07ef48398fe80f9e1e6afeff22e140e0850eef"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:10ff5faaa22786315ef57097a279b833ecab1a0bfb07d604c9cbb1c4cdc2ed87"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:007df07a6e3eb3e33e9a1fe6a9db7af152bbd8a185f9aaa6ece10a3529e3e1c6"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f4f399d28478d1f604c2ff9119907af9726aed73680e5ed1ca634d377abb087"}, + {file = 
"frozenlist-1.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c5374b80521d3d3f2ec5572e05adc94601985cc526fb276d0c8574a6d749f1b3"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ce31ae3e19f3c902de379cf1323d90c649425b86de7bbdf82871b8a2a0615f3d"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7211ef110a9194b6042449431e08c4d80c0481e5891e58d429df5899690511c2"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:556de4430ce324c836789fa4560ca62d1591d2538b8ceb0b4f68fb7b2384a27a"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7645a8e814a3ee34a89c4a372011dcd817964ce8cb273c8ed6119d706e9613e3"}, + {file = "frozenlist-1.4.0-cp39-cp39-win32.whl", hash = "sha256:19488c57c12d4e8095a922f328df3f179c820c212940a498623ed39160bc3c2f"}, + {file = "frozenlist-1.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:6221d84d463fb110bdd7619b69cb43878a11d51cbb9394ae3105d082d5199167"}, + {file = "frozenlist-1.4.0.tar.gz", hash = "sha256:09163bdf0b2907454042edb19f887c6d33806adc71fbd54afc14908bfdc22251"}, +] + +[[package]] +name = "fsspec" +version = "2023.6.0" +description = "File-system specification" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fsspec-2023.6.0-py3-none-any.whl", hash = "sha256:1cbad1faef3e391fba6dc005ae9b5bdcbf43005c9167ce78c915549c352c869a"}, + {file = "fsspec-2023.6.0.tar.gz", hash = "sha256:d0b2f935446169753e7a5c5c55681c54ea91996cc67be93c39a154fb3a2742af"}, +] + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +devel = ["pytest", "pytest-cov"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +tqdm = ["tqdm"] + +[[package]] +name = "huggingface-hub" +version = "0.16.4" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "huggingface_hub-0.16.4-py3-none-any.whl", hash = "sha256:0d3df29932f334fead024afc7cb4cc5149d955238b8b5e42dcf9740d6995a349"}, + {file = "huggingface_hub-0.16.4.tar.gz", hash = "sha256:608c7d4f3d368b326d1747f91523dbd1f692871e8e2e7a4750314a2dd8b63e14"}, +] + +[package.dependencies] +filelock = "*" +fsspec = "*" +packaging = ">=20.9" +pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", 
"pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +inference = ["aiohttp", "pydantic"] +quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"] +tensorflow = ["graphviz", "pydot", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["torch"] +typing = ["pydantic", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] + +[[package]] +name = "idna" +version = "3.4" +description = "Internationalized Domain Names in Applications (IDNA)" +optional = false +python-versions = ">=3.5" +files = [ + {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, + {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, +] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "multidict" +version = "6.0.4" +description = "multidict implementation" +optional = false +python-versions = ">=3.7" +files = [ + {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"}, + {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeb6dcc05e911516ae3d1f207d4b0520d07f54484c49dfc294d6e7d63b734171"}, + {file = "multidict-6.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d6d635d5209b82a3492508cf5b365f3446afb65ae7ebd755e70e18f287b0adf7"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c048099e4c9e9d615545e2001d3d8a4380bd403e1a0578734e0d31703d1b0c0b"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea20853c6dbbb53ed34cb4d080382169b6f4554d394015f1bef35e881bf83547"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16d232d4e5396c2efbbf4f6d4df89bfa905eb0d4dc5b3549d872ab898451f569"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c63aaa167f6c6b04ef2c85704e93af16c11d20de1d133e39de6a0e84582a93"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64bdf1086b6043bf519869678f5f2757f473dee970d7abf6da91ec00acb9cb98"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:43644e38f42e3af682690876cff722d301ac585c5b9e1eacc013b7a3f7b696a0"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7582a1d1030e15422262de9f58711774e02fa80df0d1578995c76214f6954988"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = 
"sha256:ddff9c4e225a63a5afab9dd15590432c22e8057e1a9a13d28ed128ecf047bbdc"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ee2a1ece51b9b9e7752e742cfb661d2a29e7bcdba2d27e66e28a99f1890e4fa0"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2e4369eb3d47d2034032a26c7a80fcb21a2cb22e1173d761a162f11e562caa5"}, + {file = "multidict-6.0.4-cp310-cp310-win32.whl", hash = "sha256:574b7eae1ab267e5f8285f0fe881f17efe4b98c39a40858247720935b893bba8"}, + {file = "multidict-6.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:4dcbb0906e38440fa3e325df2359ac6cb043df8e58c965bb45f4e406ecb162cc"}, + {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0dfad7a5a1e39c53ed00d2dd0c2e36aed4650936dc18fd9a1826a5ae1cad6f03"}, + {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:64da238a09d6039e3bd39bb3aee9c21a5e34f28bfa5aa22518581f910ff94af3"}, + {file = "multidict-6.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff959bee35038c4624250473988b24f846cbeb2c6639de3602c073f10410ceba"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01a3a55bd90018c9c080fbb0b9f4891db37d148a0a18722b42f94694f8b6d4c9"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5cb09abb18c1ea940fb99360ea0396f34d46566f157122c92dfa069d3e0e982"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666daae833559deb2d609afa4490b85830ab0dfca811a98b70a205621a6109fe"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11bdf3f5e1518b24530b8241529d2050014c884cf18b6fc69c0c2b30ca248710"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d18748f2d30f94f498e852c67d61261c643b349b9d2a581131725595c45ec6c"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:458f37be2d9e4c95e2d8866a851663cbc76e865b78395090786f6cd9b3bbf4f4"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b1a2eeedcead3a41694130495593a559a668f382eee0727352b9a41e1c45759a"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7d6ae9d593ef8641544d6263c7fa6408cc90370c8cb2bbb65f8d43e5b0351d9c"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5979b5632c3e3534e42ca6ff856bb24b2e3071b37861c2c727ce220d80eee9ed"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dcfe792765fab89c365123c81046ad4103fcabbc4f56d1c1997e6715e8015461"}, + {file = "multidict-6.0.4-cp311-cp311-win32.whl", hash = "sha256:3601a3cece3819534b11d4efc1eb76047488fddd0c85a3948099d5da4d504636"}, + {file = "multidict-6.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:81a4f0b34bd92df3da93315c6a59034df95866014ac08535fc819f043bfd51f0"}, + {file = "multidict-6.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:67040058f37a2a51ed8ea8f6b0e6ee5bd78ca67f169ce6122f3e2ec80dfe9b78"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:853888594621e6604c978ce2a0444a1e6e70c8d253ab65ba11657659dcc9100f"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:39ff62e7d0f26c248b15e364517a72932a611a9b75f35b45be078d81bdb86603"}, + {file = 
"multidict-6.0.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af048912e045a2dc732847d33821a9d84ba553f5c5f028adbd364dd4765092ac"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e8b901e607795ec06c9e42530788c45ac21ef3aaa11dbd0c69de543bfb79a9"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62501642008a8b9871ddfccbf83e4222cf8ac0d5aeedf73da36153ef2ec222d2"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:99b76c052e9f1bc0721f7541e5e8c05db3941eb9ebe7b8553c625ef88d6eefde"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:509eac6cf09c794aa27bcacfd4d62c885cce62bef7b2c3e8b2e49d365b5003fe"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:21a12c4eb6ddc9952c415f24eef97e3e55ba3af61f67c7bc388dcdec1404a067"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:5cad9430ab3e2e4fa4a2ef4450f548768400a2ac635841bc2a56a2052cdbeb87"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab55edc2e84460694295f401215f4a58597f8f7c9466faec545093045476327d"}, + {file = "multidict-6.0.4-cp37-cp37m-win32.whl", hash = "sha256:5a4dcf02b908c3b8b17a45fb0f15b695bf117a67b76b7ad18b73cf8e92608775"}, + {file = "multidict-6.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6ed5f161328b7df384d71b07317f4d8656434e34591f20552c7bcef27b0ab88e"}, + {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5fc1b16f586f049820c5c5b17bb4ee7583092fa0d1c4e28b5239181ff9532e0c"}, + {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1502e24330eb681bdaa3eb70d6358e818e8e8f908a22a1851dfd4e15bc2f8161"}, + {file = "multidict-6.0.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b692f419760c0e65d060959df05f2a531945af31fda0c8a3b3195d4efd06de11"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45e1ecb0379bfaab5eef059f50115b54571acfbe422a14f668fc8c27ba410e7e"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddd3915998d93fbcd2566ddf9cf62cdb35c9e093075f862935573d265cf8f65d"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59d43b61c59d82f2effb39a93c48b845efe23a3852d201ed2d24ba830d0b4cf2"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc8e1d0c705233c5dd0c5e6460fbad7827d5d36f310a0fadfd45cc3029762258"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6aa0418fcc838522256761b3415822626f866758ee0bc6632c9486b179d0b52"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6748717bb10339c4760c1e63da040f5f29f5ed6e59d76daee30305894069a660"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4d1a3d7ef5e96b1c9e92f973e43aa5e5b96c659c9bc3124acbbd81b0b9c8a951"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4372381634485bec7e46718edc71528024fcdc6f835baefe517b34a33c731d60"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:fc35cb4676846ef752816d5be2193a1e8367b4c1397b74a565a9d0389c433a1d"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = 
"sha256:4b9d9e4e2b37daddb5c23ea33a3417901fa7c7b3dee2d855f63ee67a0b21e5b1"}, + {file = "multidict-6.0.4-cp38-cp38-win32.whl", hash = "sha256:e41b7e2b59679edfa309e8db64fdf22399eec4b0b24694e1b2104fb789207779"}, + {file = "multidict-6.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:d6c254ba6e45d8e72739281ebc46ea5eb5f101234f3ce171f0e9f5cc86991480"}, + {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:16ab77bbeb596e14212e7bab8429f24c1579234a3a462105cda4a66904998664"}, + {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc779e9e6f7fda81b3f9aa58e3a6091d49ad528b11ed19f6621408806204ad35"}, + {file = "multidict-6.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ceef517eca3e03c1cceb22030a3e39cb399ac86bff4e426d4fc6ae49052cc60"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:281af09f488903fde97923c7744bb001a9b23b039a909460d0f14edc7bf59706"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52f2dffc8acaba9a2f27174c41c9e57f60b907bb9f096b36b1a1f3be71c6284d"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b41156839806aecb3641f3208c0dafd3ac7775b9c4c422d82ee2a45c34ba81ca"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3fc56f88cc98ef8139255cf8cd63eb2c586531e43310ff859d6bb3a6b51f1"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8316a77808c501004802f9beebde51c9f857054a0c871bd6da8280e718444449"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f70b98cd94886b49d91170ef23ec5c0e8ebb6f242d734ed7ed677b24d50c82cf"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bf6774e60d67a9efe02b3616fee22441d86fab4c6d335f9d2051d19d90a40063"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:e69924bfcdda39b722ef4d9aa762b2dd38e4632b3641b1d9a57ca9cd18f2f83a"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:6b181d8c23da913d4ff585afd1155a0e1194c0b50c54fcfe286f70cdaf2b7176"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52509b5be062d9eafc8170e53026fbc54cf3b32759a23d07fd935fb04fc22d95"}, + {file = "multidict-6.0.4-cp39-cp39-win32.whl", hash = "sha256:27c523fbfbdfd19c6867af7346332b62b586eed663887392cff78d614f9ec313"}, + {file = "multidict-6.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:33029f5734336aa0d4c0384525da0387ef89148dc7191aae00ca5fb23d7aafc2"}, + {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, +] + +[[package]] +name = "packaging" +version = "23.1" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, + {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, +] + +[[package]] +name = "pluggy" +version = "1.3.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"}, + {file = "pluggy-1.3.0.tar.gz", hash = 
"sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "pydantic" +version = "2.6.4" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"}, + {file = "pydantic-2.6.4.tar.gz", hash = "sha256:b1704e0847db01817624a6b86766967f552dd9dbf3afba4004409f908dcc84e6"}, +] + +[package.dependencies] +annotated-types = ">=0.4.0" +pydantic-core = "2.16.3" +typing-extensions = ">=4.6.1" + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.16.3" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic_core-2.16.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:75b81e678d1c1ede0785c7f46690621e4c6e63ccd9192af1f0bd9d504bbb6bf4"}, + {file = "pydantic_core-2.16.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9c865a7ee6f93783bd5d781af5a4c43dadc37053a5b42f7d18dc019f8c9d2bd1"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:162e498303d2b1c036b957a1278fa0899d02b2842f1ff901b6395104c5554a45"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2f583bd01bbfbff4eaee0868e6fc607efdfcc2b03c1c766b06a707abbc856187"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b926dd38db1519ed3043a4de50214e0d600d404099c3392f098a7f9d75029ff8"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:716b542728d4c742353448765aa7cdaa519a7b82f9564130e2b3f6766018c9ec"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4ad7f7ee1a13d9cb49d8198cd7d7e3aa93e425f371a68235f784e99741561f"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bd87f48924f360e5d1c5f770d6155ce0e7d83f7b4e10c2f9ec001c73cf475c99"}, + {file = "pydantic_core-2.16.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0df446663464884297c793874573549229f9eca73b59360878f382a0fc085979"}, + {file = "pydantic_core-2.16.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4df8a199d9f6afc5ae9a65f8f95ee52cae389a8c6b20163762bde0426275b7db"}, + {file = "pydantic_core-2.16.3-cp310-none-win32.whl", hash = "sha256:456855f57b413f077dff513a5a28ed838dbbb15082ba00f80750377eed23d132"}, + {file = "pydantic_core-2.16.3-cp310-none-win_amd64.whl", hash = "sha256:732da3243e1b8d3eab8c6ae23ae6a58548849d2e4a4e03a1924c8ddf71a387cb"}, + {file = "pydantic_core-2.16.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:519ae0312616026bf4cedc0fe459e982734f3ca82ee8c7246c19b650b60a5ee4"}, + {file = "pydantic_core-2.16.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b3992a322a5617ded0a9f23fd06dbc1e4bd7cf39bc4ccf344b10f80af58beacd"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d62da299c6ecb04df729e4b5c52dc0d53f4f8430b4492b93aa8de1f541c4aac"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2acca2be4bb2f2147ada8cac612f8a98fc09f41c89f87add7256ad27332c2fda"}, + {file = 
"pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1b662180108c55dfbf1280d865b2d116633d436cfc0bba82323554873967b340"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e7c6ed0dc9d8e65f24f5824291550139fe6f37fac03788d4580da0d33bc00c97"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6b1bb0827f56654b4437955555dc3aeeebeddc47c2d7ed575477f082622c49e"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e56f8186d6210ac7ece503193ec84104da7ceb98f68ce18c07282fcc2452e76f"}, + {file = "pydantic_core-2.16.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:936e5db01dd49476fa8f4383c259b8b1303d5dd5fb34c97de194560698cc2c5e"}, + {file = "pydantic_core-2.16.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:33809aebac276089b78db106ee692bdc9044710e26f24a9a2eaa35a0f9fa70ba"}, + {file = "pydantic_core-2.16.3-cp311-none-win32.whl", hash = "sha256:ded1c35f15c9dea16ead9bffcde9bb5c7c031bff076355dc58dcb1cb436c4721"}, + {file = "pydantic_core-2.16.3-cp311-none-win_amd64.whl", hash = "sha256:d89ca19cdd0dd5f31606a9329e309d4fcbb3df860960acec32630297d61820df"}, + {file = "pydantic_core-2.16.3-cp311-none-win_arm64.whl", hash = "sha256:6162f8d2dc27ba21027f261e4fa26f8bcb3cf9784b7f9499466a311ac284b5b9"}, + {file = "pydantic_core-2.16.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:0f56ae86b60ea987ae8bcd6654a887238fd53d1384f9b222ac457070b7ac4cff"}, + {file = "pydantic_core-2.16.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9bd22a2a639e26171068f8ebb5400ce2c1bc7d17959f60a3b753ae13c632975"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4204e773b4b408062960e65468d5346bdfe139247ee5f1ca2a378983e11388a2"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f651dd19363c632f4abe3480a7c87a9773be27cfe1341aef06e8759599454120"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aaf09e615a0bf98d406657e0008e4a8701b11481840be7d31755dc9f97c44053"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8e47755d8152c1ab5b55928ab422a76e2e7b22b5ed8e90a7d584268dd49e9c6b"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:500960cb3a0543a724a81ba859da816e8cf01b0e6aaeedf2c3775d12ee49cade"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cf6204fe865da605285c34cf1172879d0314ff267b1c35ff59de7154f35fdc2e"}, + {file = "pydantic_core-2.16.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d33dd21f572545649f90c38c227cc8631268ba25c460b5569abebdd0ec5974ca"}, + {file = "pydantic_core-2.16.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:49d5d58abd4b83fb8ce763be7794d09b2f50f10aa65c0f0c1696c677edeb7cbf"}, + {file = "pydantic_core-2.16.3-cp312-none-win32.whl", hash = "sha256:f53aace168a2a10582e570b7736cc5bef12cae9cf21775e3eafac597e8551fbe"}, + {file = "pydantic_core-2.16.3-cp312-none-win_amd64.whl", hash = "sha256:0d32576b1de5a30d9a97f300cc6a3f4694c428d956adbc7e6e2f9cad279e45ed"}, + {file = "pydantic_core-2.16.3-cp312-none-win_arm64.whl", hash = "sha256:ec08be75bb268473677edb83ba71e7e74b43c008e4a7b1907c6d57e940bf34b6"}, + {file = 
"pydantic_core-2.16.3-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:b1f6f5938d63c6139860f044e2538baeee6f0b251a1816e7adb6cbce106a1f01"}, + {file = "pydantic_core-2.16.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2a1ef6a36fdbf71538142ed604ad19b82f67b05749512e47f247a6ddd06afdc7"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:704d35ecc7e9c31d48926150afada60401c55efa3b46cd1ded5a01bdffaf1d48"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d937653a696465677ed583124b94a4b2d79f5e30b2c46115a68e482c6a591c8a"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9803edf8e29bd825f43481f19c37f50d2b01899448273b3a7758441b512acf8"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:72282ad4892a9fb2da25defeac8c2e84352c108705c972db82ab121d15f14e6d"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f752826b5b8361193df55afcdf8ca6a57d0232653494ba473630a83ba50d8c9"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4384a8f68ddb31a0b0c3deae88765f5868a1b9148939c3f4121233314ad5532c"}, + {file = "pydantic_core-2.16.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a4b2bf78342c40b3dc830880106f54328928ff03e357935ad26c7128bbd66ce8"}, + {file = "pydantic_core-2.16.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:13dcc4802961b5f843a9385fc821a0b0135e8c07fc3d9949fd49627c1a5e6ae5"}, + {file = "pydantic_core-2.16.3-cp38-none-win32.whl", hash = "sha256:e3e70c94a0c3841e6aa831edab1619ad5c511199be94d0c11ba75fe06efe107a"}, + {file = "pydantic_core-2.16.3-cp38-none-win_amd64.whl", hash = "sha256:ecdf6bf5f578615f2e985a5e1f6572e23aa632c4bd1dc67f8f406d445ac115ed"}, + {file = "pydantic_core-2.16.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:bda1ee3e08252b8d41fa5537413ffdddd58fa73107171a126d3b9ff001b9b820"}, + {file = "pydantic_core-2.16.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:21b888c973e4f26b7a96491c0965a8a312e13be108022ee510248fe379a5fa23"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be0ec334369316fa73448cc8c982c01e5d2a81c95969d58b8f6e272884df0074"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b5b6079cc452a7c53dd378c6f881ac528246b3ac9aae0f8eef98498a75657805"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ee8d5f878dccb6d499ba4d30d757111847b6849ae07acdd1205fffa1fc1253c"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7233d65d9d651242a68801159763d09e9ec96e8a158dbf118dc090cd77a104c9"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6119dc90483a5cb50a1306adb8d52c66e447da88ea44f323e0ae1a5fcb14256"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:578114bc803a4c1ff9946d977c221e4376620a46cf78da267d946397dc9514a8"}, + {file = "pydantic_core-2.16.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d8f99b147ff3fcf6b3cc60cb0c39ea443884d5559a30b1481e92495f2310ff2b"}, + {file = "pydantic_core-2.16.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4ac6b4ce1e7283d715c4b729d8f9dab9627586dafce81d9eaa009dd7f25dd972"}, 
+ {file = "pydantic_core-2.16.3-cp39-none-win32.whl", hash = "sha256:e7774b570e61cb998490c5235740d475413a1f6de823169b4cf94e2fe9e9f6b2"}, + {file = "pydantic_core-2.16.3-cp39-none-win_amd64.whl", hash = "sha256:9091632a25b8b87b9a605ec0e61f241c456e9248bfdcf7abdf344fdb169c81cf"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:36fa178aacbc277bc6b62a2c3da95226520da4f4e9e206fdf076484363895d2c"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:dcca5d2bf65c6fb591fff92da03f94cd4f315972f97c21975398bd4bd046854a"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a72fb9963cba4cd5793854fd12f4cfee731e86df140f59ff52a49b3552db241"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b60cc1a081f80a2105a59385b92d82278b15d80ebb3adb200542ae165cd7d183"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cbcc558401de90a746d02ef330c528f2e668c83350f045833543cd57ecead1ad"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:fee427241c2d9fb7192b658190f9f5fd6dfe41e02f3c1489d2ec1e6a5ab1e04a"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f4cb85f693044e0f71f394ff76c98ddc1bc0953e48c061725e540396d5c8a2e1"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:b29eeb887aa931c2fcef5aa515d9d176d25006794610c264ddc114c053bf96fe"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a425479ee40ff021f8216c9d07a6a3b54b31c8267c6e17aa88b70d7ebd0e5e5b"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:5c5cbc703168d1b7a838668998308018a2718c2130595e8e190220238addc96f"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99b6add4c0b39a513d323d3b93bc173dac663c27b99860dd5bf491b240d26137"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75f76ee558751746d6a38f89d60b6228fa174e5172d143886af0f85aa306fd89"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:00ee1c97b5364b84cb0bd82e9bbf645d5e2871fb8c58059d158412fee2d33d8a"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:287073c66748f624be4cef893ef9174e3eb88fe0b8a78dc22e88eca4bc357ca6"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ed25e1835c00a332cb10c683cd39da96a719ab1dfc08427d476bce41b92531fc"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:86b3d0033580bd6bbe07590152007275bd7af95f98eaa5bd36f3da219dcd93da"}, + {file = "pydantic_core-2.16.3.tar.gz", hash = "sha256:1cac689f80a3abab2d3c0048b29eea5751114054f032a941a32de4c852c59cad"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + +[[package]] +name = "pytest" +version = "7.4.0" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.4.0-py3-none-any.whl", hash = "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32"}, + {file = "pytest-7.4.0.tar.gz", hash = "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"}, +] + 
+[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-asyncio" +version = "0.21.1" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"}, + {file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + +[[package]] +name = "pywin32" +version = "306" +description = "Python for Window Extensions" +optional = false +python-versions = "*" +files = [ + {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, + {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, + {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"}, + {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"}, + {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"}, + {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"}, + {file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"}, + {file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"}, + {file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"}, + {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"}, + {file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"}, + {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"}, + {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"}, + {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, +] + +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = 
"PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = 
"sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + +[[package]] +name = "requests" +version = "2.31.0" +description = "Python HTTP for Humans." +optional = false +python-versions = ">=3.7" +files = [ + {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, + {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, +] + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<3" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] + +[[package]] +name = "syrupy" +version = "4.0.1" +description = "Pytest Snapshot Test Utility" +optional = false +python-versions = ">=3.8.1,<4" +files = [ + {file = "syrupy-4.0.1-py3-none-any.whl", hash = "sha256:53d3107cc5e18a5def189c721879cea2cdafdee34b879f602133ca08837d0e4b"}, + {file = "syrupy-4.0.1.tar.gz", hash = "sha256:60e3e94782444e0f978cd3b207de32f6da3199b15a2db32eab02f83cebb63ae8"}, +] + +[package.dependencies] +colored = ">=1.3.92,<2.0.0" +pytest = ">=7.0.0,<8.0.0" + +[[package]] +name = "text-generation" +version = "0.6.1" +description = "Hugging Face Text Generation Python Client" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "text_generation-0.6.1-py3-none-any.whl", hash = "sha256:ebca00587eeabc0f5118f66ee1048bf690bd7735a9a10361c533c31c8c0bf994"}, + {file = "text_generation-0.6.1.tar.gz", hash = "sha256:730e662aa7812f73c08ab953e008e90455f3d046f81efa0ef3de462bd4cf63d9"}, +] + +[package.dependencies] +aiohttp = ">=3.8,<4.0" +huggingface-hub = ">=0.12,<1.0" +pydantic = ">1.10,<3" + +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + +[[package]] +name = "tqdm" +version = "4.66.1" +description = "Fast, Extensible Progress Meter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"}, + {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + +[[package]] +name = "typing-extensions" +version = "4.7.1" +description = "Backported and Experimental Type Hints for Python 3.7+" +optional = false +python-versions = ">=3.7" +files = [ + {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, + {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, +] + +[[package]] +name = "urllib3" +version = 
"2.0.4" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = ">=3.7" +files = [ + {file = "urllib3-2.0.4-py3-none-any.whl", hash = "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4"}, + {file = "urllib3-2.0.4.tar.gz", hash = "sha256:8d22f86aae8ef5e410d4f539fde9ce6b2113a001bb4d189e0aed70642d602b11"}, +] + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + +[[package]] +name = "websocket-client" +version = "1.6.2" +description = "WebSocket client for Python with low level API options" +optional = false +python-versions = ">=3.8" +files = [ + {file = "websocket-client-1.6.2.tar.gz", hash = "sha256:53e95c826bf800c4c465f50093a8c4ff091c7327023b10bfaff40cf1ef170eaa"}, + {file = "websocket_client-1.6.2-py3-none-any.whl", hash = "sha256:ce54f419dfae71f4bdba69ebe65bf7f0a93fe71bc009ad3a010aacc3eebad537"}, +] + +[package.extras] +docs = ["Sphinx (>=6.0)", "sphinx-rtd-theme (>=1.1.0)"] +optional = ["python-socks", "wsaccel"] +test = ["websockets"] + +[[package]] +name = "yarl" +version = "1.9.2" +description = "Yet another URL library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82"}, + {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82aa6264b36c50acfb2424ad5ca537a2060ab6de158a5bd2a72a032cc75b9eb8"}, + {file = "yarl-1.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c0c77533b5ed4bcc38e943178ccae29b9bcf48ffd1063f5821192f23a1bd27b9"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee4afac41415d52d53a9833ebae7e32b344be72835bbb589018c9e938045a560"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9bf345c3a4f5ba7f766430f97f9cc1320786f19584acc7086491f45524a551ac"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a96c19c52ff442a808c105901d0bdfd2e28575b3d5f82e2f5fd67e20dc5f4ea"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:891c0e3ec5ec881541f6c5113d8df0315ce5440e244a716b95f2525b7b9f3608"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3a53ba34a636a256d767c086ceb111358876e1fb6b50dfc4d3f4951d40133d5"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:566185e8ebc0898b11f8026447eacd02e46226716229cea8db37496c8cdd26e0"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b0738fb871812722a0ac2154be1f049c6223b9f6f22eec352996b69775b36d4"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:32f1d071b3f362c80f1a7d322bfd7b2d11e33d2adf395cc1dd4df36c9c243095"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e9fdc7ac0d42bc3ea78818557fab03af6181e076a2944f43c38684b4b6bed8e3"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56ff08ab5df8429901ebdc5d15941b59f6253393cb5da07b4170beefcf1b2528"}, + {file = "yarl-1.9.2-cp310-cp310-win32.whl", hash = "sha256:8ea48e0a2f931064469bdabca50c2f578b565fc446f302a79ba6cc0ee7f384d3"}, + {file = 
"yarl-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:50f33040f3836e912ed16d212f6cc1efb3231a8a60526a407aeb66c1c1956dde"}, + {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:646d663eb2232d7909e6601f1a9107e66f9791f290a1b3dc7057818fe44fc2b6"}, + {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aff634b15beff8902d1f918012fc2a42e0dbae6f469fce134c8a0dc51ca423bb"}, + {file = "yarl-1.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a83503934c6273806aed765035716216cc9ab4e0364f7f066227e1aaea90b8d0"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b25322201585c69abc7b0e89e72790469f7dad90d26754717f3310bfe30331c2"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22a94666751778629f1ec4280b08eb11815783c63f52092a5953faf73be24191"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ec53a0ea2a80c5cd1ab397925f94bff59222aa3cf9c6da938ce05c9ec20428d"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:159d81f22d7a43e6eabc36d7194cb53f2f15f498dbbfa8edc8a3239350f59fe7"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832b7e711027c114d79dffb92576acd1bd2decc467dec60e1cac96912602d0e6"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:95d2ecefbcf4e744ea952d073c6922e72ee650ffc79028eb1e320e732898d7e8"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d4e2c6d555e77b37288eaf45b8f60f0737c9efa3452c6c44626a5455aeb250b9"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:783185c75c12a017cc345015ea359cc801c3b29a2966c2655cd12b233bf5a2be"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:b8cc1863402472f16c600e3e93d542b7e7542a540f95c30afd472e8e549fc3f7"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:822b30a0f22e588b32d3120f6d41e4ed021806418b4c9f0bc3048b8c8cb3f92a"}, + {file = "yarl-1.9.2-cp311-cp311-win32.whl", hash = "sha256:a60347f234c2212a9f0361955007fcf4033a75bf600a33c88a0a8e91af77c0e8"}, + {file = "yarl-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:be6b3fdec5c62f2a67cb3f8c6dbf56bbf3f61c0f046f84645cd1ca73532ea051"}, + {file = "yarl-1.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38a3928ae37558bc1b559f67410df446d1fbfa87318b124bf5032c31e3447b74"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac9bb4c5ce3975aeac288cfcb5061ce60e0d14d92209e780c93954076c7c4367"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3da8a678ca8b96c8606bbb8bfacd99a12ad5dd288bc6f7979baddd62f71c63ef"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13414591ff516e04fcdee8dc051c13fd3db13b673c7a4cb1350e6b2ad9639ad3"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf74d08542c3a9ea97bb8f343d4fcbd4d8f91bba5ec9d5d7f792dbe727f88938"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e7221580dc1db478464cfeef9b03b95c5852cc22894e418562997df0d074ccc"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:494053246b119b041960ddcd20fd76224149cfea8ed8777b687358727911dd33"}, + {file = 
"yarl-1.9.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:52a25809fcbecfc63ac9ba0c0fb586f90837f5425edfd1ec9f3372b119585e45"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:e65610c5792870d45d7b68c677681376fcf9cc1c289f23e8e8b39c1485384185"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:1b1bba902cba32cdec51fca038fd53f8beee88b77efc373968d1ed021024cc04"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:662e6016409828ee910f5d9602a2729a8a57d74b163c89a837de3fea050c7582"}, + {file = "yarl-1.9.2-cp37-cp37m-win32.whl", hash = "sha256:f364d3480bffd3aa566e886587eaca7c8c04d74f6e8933f3f2c996b7f09bee1b"}, + {file = "yarl-1.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6a5883464143ab3ae9ba68daae8e7c5c95b969462bbe42e2464d60e7e2698368"}, + {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5610f80cf43b6202e2c33ba3ec2ee0a2884f8f423c8f4f62906731d876ef4fac"}, + {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9a4e67ad7b646cd6f0938c7ebfd60e481b7410f574c560e455e938d2da8e0f4"}, + {file = "yarl-1.9.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83fcc480d7549ccebe9415d96d9263e2d4226798c37ebd18c930fce43dfb9574"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fcd436ea16fee7d4207c045b1e340020e58a2597301cfbcfdbe5abd2356c2fb"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84e0b1599334b1e1478db01b756e55937d4614f8654311eb26012091be109d59"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3458a24e4ea3fd8930e934c129b676c27452e4ebda80fbe47b56d8c6c7a63a9e"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:838162460b3a08987546e881a2bfa573960bb559dfa739e7800ceeec92e64417"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4e2d08f07a3d7d3e12549052eb5ad3eab1c349c53ac51c209a0e5991bbada78"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:de119f56f3c5f0e2fb4dee508531a32b069a5f2c6e827b272d1e0ff5ac040333"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:149ddea5abf329752ea5051b61bd6c1d979e13fbf122d3a1f9f0c8be6cb6f63c"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:674ca19cbee4a82c9f54e0d1eee28116e63bc6fd1e96c43031d11cbab8b2afd5"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9b3152f2f5677b997ae6c804b73da05a39daa6a9e85a512e0e6823d81cdad7cc"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415d5a4b080dc9612b1b63cba008db84e908b95848369aa1da3686ae27b6d2b"}, + {file = "yarl-1.9.2-cp38-cp38-win32.whl", hash = "sha256:f7a3d8146575e08c29ed1cd287068e6d02f1c7bdff8970db96683b9591b86ee7"}, + {file = "yarl-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:63c48f6cef34e6319a74c727376e95626f84ea091f92c0250a98e53e62c77c72"}, + {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75df5ef94c3fdc393c6b19d80e6ef1ecc9ae2f4263c09cacb178d871c02a5ba9"}, + {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c027a6e96ef77d401d8d5a5c8d6bc478e8042f1e448272e8d9752cb0aff8b5c8"}, + {file = "yarl-1.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3b078dbe227f79be488ffcfc7a9edb3409d018e0952cf13f15fd6512847f3f7"}, + {file = 
"yarl-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59723a029760079b7d991a401386390c4be5bfec1e7dd83e25a6a0881859e716"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b03917871bf859a81ccb180c9a2e6c1e04d2f6a51d953e6a5cdd70c93d4e5a2a"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1012fa63eb6c032f3ce5d2171c267992ae0c00b9e164efe4d73db818465fac3"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74dcbfe780e62f4b5a062714576f16c2f3493a0394e555ab141bf0d746bb955"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c56986609b057b4839968ba901944af91b8e92f1725d1a2d77cbac6972b9ed1"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c315df3293cd521033533d242d15eab26583360b58f7ee5d9565f15fee1bef4"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b7232f8dfbd225d57340e441d8caf8652a6acd06b389ea2d3222b8bc89cbfca6"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:53338749febd28935d55b41bf0bcc79d634881195a39f6b2f767870b72514caf"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:066c163aec9d3d073dc9ffe5dd3ad05069bcb03fcaab8d221290ba99f9f69ee3"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8288d7cd28f8119b07dd49b7230d6b4562f9b61ee9a4ab02221060d21136be80"}, + {file = "yarl-1.9.2-cp39-cp39-win32.whl", hash = "sha256:b124e2a6d223b65ba8768d5706d103280914d61f5cae3afbc50fc3dfcc016623"}, + {file = "yarl-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:61016e7d582bc46a5378ffdd02cd0314fb8ba52f40f9cf4d9a5e7dbef88dee18"}, + {file = "yarl-1.9.2.tar.gz", hash = "sha256:04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"}, +] + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" + +[metadata] +lock-version = "2.0" +python-versions = ">=3.9,<3.13" +content-hash = "421fbce065cb1499c666599cf0fd83a5ce8fb3bed09e83c16c3a3d6953b34026" diff --git a/integration-tests/pyproject.toml b/integration-tests/pyproject.toml new file mode 100644 index 00000000..88e9761a --- /dev/null +++ b/integration-tests/pyproject.toml @@ -0,0 +1,14 @@ +[tool.poetry] +name = "text-generation-integration-tests" +version = "2.0.1" +description = "Text Generation Inference integration tests" +authors = ["Nicolas Patry "] + +[tool.poetry.dependencies] +pydantic = "> 2, < 3" +python = ">=3.9,<3.13" +syrupy = "4.0.1" +text-generation = "^0.6.0" +pytest = "^7.4.0" +pytest-asyncio = "^0.21.1" +docker = "^6.1.3" diff --git a/integration-tests/pytest.ini b/integration-tests/pytest.ini new file mode 100644 index 00000000..bab689d7 --- /dev/null +++ b/integration-tests/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +addopts = --snapshot-warn-unused +asyncio_mode = auto +markers = + private: marks tests as requiring an admin hf token (deselect with '-m "not private"') diff --git a/integration-tests/requirements.txt b/integration-tests/requirements.txt new file mode 100644 index 00000000..3c2ce11b --- /dev/null +++ b/integration-tests/requirements.txt @@ -0,0 +1,35 @@ +aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "3.13" +aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13" +annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13" 
+attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13" +certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13" +charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") +colored==1.4.4 ; python_version >= "3.9" and python_version < "3.13" +docker==6.1.3 ; python_version >= "3.9" and python_version < "3.13" +exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11" +filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13" +frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "3.13" +fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13" +idna==3.4 ; python_version >= "3.9" and python_version < "3.13" +iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13" +multidict==6.0.4 ; python_version >= "3.9" and python_version < "3.13" +packaging==23.1 ; python_version >= "3.9" and python_version < "3.13" +pluggy==1.3.0 ; python_version >= "3.9" and python_version < "3.13" +pydantic-core==2.16.3 ; python_version >= "3.9" and python_version < "3.13" +pydantic==2.6.4 ; python_version >= "3.9" and python_version < "3.13" +pytest-asyncio==0.21.1 ; python_version >= "3.9" and python_version < "3.13" +pytest==7.4.0 ; python_version >= "3.9" and python_version < "3.13" +pywin32==306 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" +pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" +requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" +syrupy==4.0.1 ; python_version >= "3.9" and python_version < "3.13" +text-generation==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11" +tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13" +websocket-client==1.6.2 ; python_version >= "3.9" and python_version < "3.13" +yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml new file mode 100644 index 00000000..d9abd5b6 --- /dev/null +++ b/launcher/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "text-generation-launcher" +description = "Text Generation Launcher" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +clap = { version = "4.4.5", features = ["derive", "env"] } +ctrlc = { version = "3.4.1", features = ["termination"] } +hf-hub = "0.3.2" +nix = { version = "0.28.0", features = ["signal"] } +once_cell = "1.19.0" +serde = { version = "1.0.188", features = ["derive"] } +serde_json = "1.0.107" +thiserror = "1.0.59" +tracing = "0.1.37" +tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +bitstream-io = { version = "2.3.0" } + +[dev-dependencies] +float_eq = "1.0.1" +reqwest = { version = "0.11.20", features = ["blocking", "json"] } + +[build-dependencies] +vergen = { version = "8.2.5", features = ["build", "cargo", "git", "gitcl", "rustc", "si"] } diff --git a/launcher/build.rs b/launcher/build.rs new file mode 100644 index 00000000..71d2c0c5 --- /dev/null +++ b/launcher/build.rs @@ -0,0 +1,29 @@ +use std::error::Error; 
+use vergen::EmitBuilder; + +fn main() -> Result<(), Box> { + // Emit cargo and rustc compile time values + EmitBuilder::builder().all_cargo().all_rustc().emit()?; + + // Try to get the git sha from the local git repository + if EmitBuilder::builder() + .fail_on_error() + .git_sha(false) + .emit() + .is_err() + { + // Unable to get the git sha + if let Ok(sha) = std::env::var("GIT_SHA") { + // Set it from an env var + println!("cargo:rustc-env=VERGEN_GIT_SHA={sha}"); + } + } + + // Set docker label if present + if let Ok(label) = std::env::var("DOCKER_LABEL") { + // Set it from an env var + println!("cargo:rustc-env=DOCKER_LABEL={label}"); + } + + Ok(()) +} diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs new file mode 100644 index 00000000..08fb301c --- /dev/null +++ b/launcher/src/env_runtime.rs @@ -0,0 +1,56 @@ +use std::fmt; +use std::process::Command; + +pub(crate) struct Env { + cargo_target: &'static str, + cargo_version: &'static str, + git_sha: &'static str, + docker_label: &'static str, + nvidia_env: String, + xpu_env: String, +} + +impl Env { + pub fn new() -> Self { + let nvidia_env = nvidia_smi(); + let xpu_env = xpu_smi(); + + Self { + nvidia_env: nvidia_env.unwrap_or("N/A".to_string()), + xpu_env: xpu_env.unwrap_or("N/A".to_string()), + cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"), + cargo_version: env!("VERGEN_RUSTC_SEMVER"), + git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"), + docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), + } + } +} + +impl fmt::Display for Env { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "Runtime environment:")?; + + writeln!(f, "Target: {}", self.cargo_target)?; + writeln!(f, "Cargo version: {}", self.cargo_version)?; + writeln!(f, "Commit sha: {}", self.git_sha)?; + writeln!(f, "Docker label: {}", self.docker_label)?; + writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?; + write!(f, "xpu-smi:\n{}", self.xpu_env)?; + + Ok(()) + } +} + +fn nvidia_smi() -> Option { + let output = Command::new("nvidia-smi").output().ok()?; + let nvidia_smi = String::from_utf8(output.stdout).ok()?; + let output = nvidia_smi.replace('\n', "\n "); + Some(output.trim().to_string()) +} + +fn xpu_smi() -> Option { + let output = Command::new("xpu-smi").arg("discovery").output().ok()?; + let xpu_smi = String::from_utf8(output.stdout).ok()?; + let output = xpu_smi.replace('\n', "\n "); + Some(output.trim().to_string()) +} diff --git a/launcher/src/main.rs b/launcher/src/main.rs index e487e2d1..4eae11cc 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -55,6 +55,10 @@ enum Quantization { /// Should be a drop-in replacement to bitsandbytes with much better performance. /// Kernels are from Eetq, + /// Variable bit quantization. Requires a specific EXL2 quantized model: + /// . Requires exllama2 kernels and does + /// not support tensor parallelism (num_shard > 1). + Exl2, /// 4 bit quantization. Requires a specific GTPQ quantized model: . /// text-generation-inference will use exllama (faster) kernels wherever possible, and use /// triton kernel (wider support) when it's not. @@ -95,6 +99,9 @@ impl std::fmt::Display for Quantization { Quantization::BitsandbytesFP4 => { write!(f, "bitsandbytes-fp4") } + Quantization::Exl2 => { + write!(f, "exl2") + } Quantization::Gptq => { write!(f, "gptq") } @@ -229,7 +236,7 @@ struct Args { max_stop_sequences: usize, /// This is the maximum allowed value for clients to set `top_n_tokens`. 
- /// `top_n_tokens is used to return information about the the `n` most likely + /// `top_n_tokens` is used to return information about the the `n` most likely /// tokens at each generation step, instead of just the sampled token. This /// information can be used for downstream tasks like for classification or /// ranking. @@ -478,7 +485,7 @@ fn shard_manager( status_sender: mpsc::Sender, shutdown: Arc, _shutdown_sender: mpsc::Sender<()>, - lora_ids: String + lora_ids: String, ) { // Enter shard-manager tracing span let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); @@ -495,8 +502,6 @@ fn shard_manager( let mut shard_args = vec![ "serve".to_string(), model_id, - "--lora-ids".to_string(), - lora_ids, "--uds-path".to_string(), uds_path, "--logger-level".to_string(), @@ -504,6 +509,11 @@ fn shard_manager( "--json-output".to_string(), ]; + if lora_ids != *"empty" { + shard_args.push("--lora-ids".to_string()); + shard_args.push(lora_ids.clone()); + } + // Activate trust remote code if trust_remote_code { shard_args.push("--trust-remote-code".to_string()); @@ -998,12 +1008,11 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L Ok(()) } - fn download_lora_adapters(args: &Args, running: Arc) -> Result<(), LauncherError> { // Enter download tracing span let _span = tracing::span!(tracing::Level::INFO, "download").entered(); - let mut download_args = vec![ + let download_args = vec![ "download-lora-adapters".to_string(), args.lora_ids.to_string(), ]; @@ -1035,15 +1044,6 @@ fn download_lora_adapters(args: &Args, running: Arc) -> Result<(), L envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) }; - // If args.weights_cache_override is some, pass it to the download process - // Useful when running inside a HuggingFace Inference Endpoint - if let Some(weights_cache_override) = &args.weights_cache_override { - envs.push(( - "WEIGHTS_CACHE_OVERRIDE".into(), - weights_cache_override.into(), - )); - }; - // Start process tracing::info!("Starting LoRA adapter download process."); let mut download_process = match Command::new("text-generation-server") @@ -1155,6 +1155,7 @@ fn spawn_shards( let rope_factor = args.rope_factor; let max_batch_size = args.max_batch_size; let lora_ids = args.lora_ids.clone(); + thread::spawn(move || { shard_manager( model_id, @@ -1587,6 +1588,11 @@ fn main() -> Result<(), LauncherError> { let num_shard = find_num_shards(args.sharded, args.num_shard)?; if num_shard > 1 { + if matches!(args.quantize, Some(Quantization::Exl2)) { + return Err(LauncherError::ArgumentValidation( + "Sharding is currently not supported with `exl2` quantization".into(), + )); + } tracing::info!("Sharding model on {num_shard} processes"); } @@ -1631,7 +1637,9 @@ fn main() -> Result<(), LauncherError> { download_convert_model(&args, running.clone())?; // Download LoRA adapters - download_lora_adapters(&args, running.clone())?; + if args.lora_ids != *"empty" { + download_lora_adapters(&args, running.clone())?; + } if !running.load(Ordering::SeqCst) { // Launcher was asked to stop diff --git a/load_tests/Makefile b/load_tests/Makefile new file mode 100644 index 00000000..9199aa3b --- /dev/null +++ b/load_tests/Makefile @@ -0,0 +1,9 @@ + +ShareGPT_V3_unfiltered_cleaned_split.json: + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +prepare_share: ShareGPT_V3_unfiltered_cleaned_split.json + python filter.py + +prepare_orca: + python orca.py diff --git 
a/load_tests/common.js b/load_tests/common.js new file mode 100644 index 00000000..e0a10595 --- /dev/null +++ b/load_tests/common.js @@ -0,0 +1,94 @@ +import { check } from 'k6'; +import { scenario } from 'k6/execution'; +import http from 'k6/http'; +import { Trend, Counter } from 'k6/metrics'; + +const host = __ENV.HOST; +const model_id = __ENV.MODEL_ID; +const timePerToken = new Trend('time_per_token', true); +const tokens = new Counter('tokens'); +const new_tokens = new Counter('new_tokens'); +const input_tokens = new Counter('input_tokens'); +const max_new_tokens = 50; + +// const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json")) +const shareGPT = JSON.parse(open("small.json")) + + +export function get_options() { + return { + thresholds: { + http_req_failed: ['rate==0'], + // time_per_token: [{ + // threshold: `p(50)<${5 * reference_latency_ms}`, + // abortOnFail: true, + // delayAbortEval: '10s' + // }], + }, + scenarios: { + // single_user: { + // executor: 'constant-arrival-rate', + // duration: '60s', + // preAllocatedVUs: 1, + // rate: 20, + // timeUnit: '1s', + // }, + load_test: { + executor: 'constant-arrival-rate', + duration: '60s', + preAllocatedVUs: 100, + rate: 1, + timeUnit: '1s', + }, + // breakpoint: { + // executor: 'ramping-arrival-rate', //Assure load increase if the system slows + // preAllocatedVUs: 300, + // stages: [ + // { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load + // ], + // }, + // throughput: { + // executor: 'shared-iterations', + // vus: 100, + // iterations: 200, + // maxDuration: '40s', + // }, + }, + }; +} + +function generate_payload(gpt, max_new_tokens) { + const input = gpt["conversations"][0]["value"]; + return { "messages": [{ "role": "user", "content": input }], "temperature": 0, "model": `${model_id}`, "max_tokens": max_new_tokens } +} + +export const options = get_options(); + +export default function run() { + const headers = { 'Content-Type': 'application/json' }; + const query = shareGPT[scenario.iterationInTest % shareGPT.length]; + const payload = JSON.stringify(generate_payload(query, max_new_tokens)); + const res = http.post(`http://${host}/v1/chat/completions`, payload, { + headers, + }); + if (res.status >= 400 && res.status < 500) { + return; + } + + + check(res, { + 'Post status is 200': (res) => res.status === 200, + }); + const duration = res.timings.duration; + + if (res.status === 200) { + const body = res.json(); + const completion_tokens = body.usage.completion_tokens; + const latency_ms_per_token = duration / completion_tokens; + timePerToken.add(latency_ms_per_token); + const prompt_tokens = body.usage.prompt_tokens; + input_tokens.add(prompt_tokens); + new_tokens.add(completion_tokens); + tokens.add(completion_tokens + prompt_tokens); + } +} diff --git a/load_tests/filter.py b/load_tests/filter.py new file mode 100644 index 00000000..a00226ed --- /dev/null +++ b/load_tests/filter.py @@ -0,0 +1,26 @@ +import json + + +def main(): + with open("./ShareGPT_V3_unfiltered_cleaned_split.json", "r") as f: + data = json.load(f) + + # Select only the first 2k conversations that start with a human. 
+    max = 2000
+    conversations = []
+    for conversation in data:
+        conv = conversation.get("conversations")
+        if conv and conv[0]["from"] == "human":
+            # Trim the rest of the output
+            conversation["conversations"] = conversation["conversations"][:1]
+            conversations.append(conversation)
+
+        if len(conversations) >= max:
+            break
+
+    with open("./small.json", "w") as f:
+        json.dump(conversations, f, indent=4)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/load_tests/orca.py b/load_tests/orca.py
new file mode 100644
index 00000000..e607d27c
--- /dev/null
+++ b/load_tests/orca.py
@@ -0,0 +1,27 @@
+import json
+import datasets
+import tqdm
+
+
+def main():
+    dataset = datasets.load_dataset("Open-Orca/OpenOrca", split="train")
+    # Select only the first 2k conversations that start with a human.
+    max = min(2000, len(dataset))
+    conversations = []
+    for item in tqdm.tqdm(dataset, total=max):
+        conversation = {
+            "conversations": [
+                {"from": "human", "value": item["question"]},
+            ],
+            "id": item["id"],
+        }
+        conversations.append(conversation)
+        if len(conversations) >= max:
+            break
+
+    with open("./small.json", "w") as f:
+        json.dump(conversations, f, indent=4)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/router/Cargo.toml b/router/Cargo.toml
new file mode 100644
index 00000000..b595f55b
--- /dev/null
+++ b/router/Cargo.toml
@@ -0,0 +1,61 @@
+[package]
+name = "text-generation-router"
+description = "Text Generation Webserver"
+build = "build.rs"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+[lib]
+path = "src/lib.rs"
+
+[[bin]]
+name = "text-generation-router"
+path = "src/main.rs"
+
+[dependencies]
+async-stream = "0.3.5"
+axum = { version = "0.7", features = ["json"] }
+axum-tracing-opentelemetry = "0.16"
+text-generation-client = { path = "client" }
+clap = { version = "4.4.5", features = ["derive", "env"] }
+futures = "0.3.28"
+hf-hub = { workspace = true }
+jsonschema = { version = "0.17.1", features = ["draft202012"] }
+metrics = "0.21.1"
+metrics-exporter-prometheus = { version = "0.12.1", features = [] }
+nohash-hasher = "0.2.0"
+opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
+opentelemetry-otlp = "0.13.0"
+rand = "0.8.5"
+reqwest = { version = "0.11.20", features = [] }
+serde = "1.0.188"
+serde_json = "1.0.107"
+thiserror = "1.0.48"
+tokenizers = { workspace = true}
+tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.14"
+tower-http = { version = "0.5.1", features = ["cors"] }
+tracing = "0.1.37"
+tracing-opentelemetry = "0.21.0"
+tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
+utoipa = { version = "4.2.0", features = ["axum_extras"] }
+utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
+ngrok = { version = "0.13.1", features = ["axum"], optional = true }
+init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
+minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" }
+futures-util = "0.3.30"
+regex = "1.10.3"
+once_cell = "1.19.0"
+image = "0.25.1"
+base64 = "0.22.0"
+bitstream-io = { version = "2.3.0" }
+
+[build-dependencies]
+vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
+
+[features]
+default = ["ngrok"]
+ngrok = ["dep:ngrok"]
+google = []
diff --git a/router/README.md b/router/README.md
new file mode 100644
index 00000000..5b1f9e36
--- /dev/null
+++ b/router/README.md
@@ -0,0 +1,93 @@
+# Router
+
+Also named `webserver` throughout the docs.
+
+This router handles most of the batching logic: it decides when to schedule new
+`prefill` requests, when to pause `decode` requests, and which requests go into which batch.
+
+It uses gRPC to communicate with the shards, which can therefore be kept
+much simpler and focus on making the forward passes as efficient as possible.
+
+## Continuous batching
+
+One important feature of `text-generation-inference` is enabled
+by this `router`.
+
+Continuous batching is the act of regularly adding queries to the same
+`forward` step of the LLM (a "batch") and removing them when they are
+finished.
+
+In order for continuous batching to be useful, you need to have more compute available
+relative to the memory requirements of your model. This is essentially always true for
+LLMs, and the larger the model, the truer it gets (since you have to pool multiple
+GPUs to load the model, you effectively have a lot of compute power at hand).
+
+Static batching is the act of running several queries at the same time, but usually
+this is controlled by the client, and therefore the amount of batching is decided
+beforehand.
+
+For text generation, where LLMs are memory bound, we can make much more efficient use
+of the available compute by having clients send us single queries and letting the router
+mix and match queries into and out of batches. This is possible because, for LLMs, the
+total compute for running the model is much bigger than the cost of mixing and matching
+the batches themselves.
+
+### Simple continuous batching
+
+Text generation works by feeding a prompt to a model and iteratively calling
+`forward` on the model to produce new text, one token at a time.
+
+The first idea is simple: when a query arrives, we start working on it directly.
+When new queries arrive, we simply wait for the current `forward` to finish,
+then batch the currently running prompt with the new query, and call `forward` again.
+
+Whenever a query is finished, either because the model produced an EOS (end of sequence)
+token or because the query reached its allowed limit, we simply drop it from the batch,
+free all its allocated memory, and continue with the rest until nothing is left.
+
+This simple idea generalizes very well, and we can potentially stack many requests
+in the same batch.
+
+One thing to note is that queries can be run with different parameters, meaning
+different ways of choosing the next token (sampling or not, temperature, top_k, etc.).
+This is not a problem for the proposed approach; we just need to do the sampling
+independently on each member of the batch.
+
+### Prefill, decode and past key values
+
+In order to make LLMs and text generation efficient, there is a very powerful
+trick that can be used: the "caching" of some attention matrices. [More on that
+in the first part of this blog](https://huggingface.co/blog/accelerated-inference#getting-to-the-first-10x-speedup)
+
+What this means is that the first "pass" over a prompt is different from the subsequent
+`forward` passes: the first one has to compute the entire attention matrix, whereas the
+follow-ups only need to compute the attention for the new token.
+The first pass is called `prefill` throughout this codebase, whereas the follow-ups are called `decode`.
+
+Since `prefill` is much more expensive than `decode`, we don't want to run it all the time,
+but a currently running query is probably doing `decode`. If we want to do continuous
+batching as explained previously, we need to run `prefill` at some point in order to create
+the attention cache required for that query to join the `decode` group.
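+
+Below is a deliberately simplified, self-contained sketch of the loop described above. It is
+not the router's actual implementation (which is asynchronous Rust with many more parameters);
+the `Query`, `prefill`, `decode`, and `serve` names are illustrative only.
+
+```python
+# Toy sketch of continuous batching: new queries join between forward steps after a
+# prefill pass, finished queries are dropped immediately. Illustrative only.
+from collections import deque
+from dataclasses import dataclass, field
+
+
+@dataclass
+class Query:
+    prompt: str
+    max_new_tokens: int
+    generated: list = field(default_factory=list)
+
+    @property
+    def finished(self) -> bool:
+        # Real criteria: EOS token produced or max_new_tokens reached.
+        return len(self.generated) >= self.max_new_tokens
+
+
+def prefill(query: Query) -> None:
+    # Expensive: attention over the whole prompt builds the cache and emits one token.
+    query.generated.append("<token>")
+
+
+def decode(batch: list) -> None:
+    # Cheap per query: one forward step emits one new token for every batch member.
+    for query in batch:
+        query.generated.append("<token>")
+
+
+def serve(incoming: deque, max_batch_size: int = 32) -> None:
+    batch: list = []
+    while incoming or batch:
+        # Admit waiting queries between forward steps; each needs a prefill first.
+        while incoming and len(batch) < max_batch_size:
+            query = incoming.popleft()
+            prefill(query)
+            batch.append(query)
+        decode(batch)
+        # Drop finished queries and free their memory.
+        batch = [q for q in batch if not q.finished]
+
+
+serve(deque([Query("Hello", 3), Query("Hi there", 5)]))
+```
+
+The real router makes the admission decision with additional budgets and heuristics,
+but the shape of the loop is the same.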
+
+`text-generation-inference` uses a bunch of different strategies and parameters in
+order to enable you to find the sweet spot between exploiting the hardware and perceived latency.
+
+With no continuous batching at all, latency is going to be super good, but throughput (meaning
+the total number of requests allowed in a given timeframe) is going to be super bad (since it's essentially 1).
+
+With static batching, you can probably reach the maximum throughput (by using the maximum total
+batch size applicable to your hardware), but the latency is super bad since, in order to have
+maximum throughput, you need to wait for requests to come in before processing.
+
+With continuous batching you can find a sweet spot. In general, latency is the most critical
+parameter users care about, but a 2x latency slowdown for 10x more users on the same
+hardware is an acceptable tradeoff.
+
+## Token streaming
+
+This is a very important aspect of client UX. As mentioned above, latency is the
+most critical perceived quality of an LLM API.
+
+With token streaming, the server can start answering directly after the first `prefill` pass,
+without waiting for all the generation to be done. For extremely long queries
+this means clients can start to see something happening orders of magnitude before
+the work is done. Seeing something in progress allows them to cut short if it's not
+what they wanted, but it also simply "feels" better.
diff --git a/router/build.rs b/router/build.rs
new file mode 100644
index 00000000..f5eb8a26
--- /dev/null
+++ b/router/build.rs
@@ -0,0 +1,26 @@
+use std::error::Error;
+use vergen::EmitBuilder;
+
+fn main() -> Result<(), Box<dyn Error>> {
+    // Try to get the git sha from the local git repository
+    if EmitBuilder::builder()
+        .fail_on_error()
+        .git_sha(false)
+        .emit()
+        .is_err()
+    {
+        // Unable to get the git sha
+        if let Ok(sha) = std::env::var("GIT_SHA") {
+            // Set it from an env var
+            println!("cargo:rustc-env=VERGEN_GIT_SHA={sha}");
+        }
+    }
+
+    // Set docker label if present
+    if let Ok(label) = std::env::var("DOCKER_LABEL") {
+        // Set it from an env var
+        println!("cargo:rustc-env=DOCKER_LABEL={label}");
+    }
+
+    Ok(())
+}
diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml
new file mode 100644
index 00000000..d0131784
--- /dev/null
+++ b/router/client/Cargo.toml
@@ -0,0 +1,20 @@
+[package]
+name = "text-generation-client"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+[dependencies]
+futures = "^0.3"
+grpc-metadata = { path = "../grpc-metadata" }
+prost = "^0.12"
+thiserror = "^1.0"
+tokio = { version = "^1.32", features = ["sync"] }
+tonic = "^0.10"
+tower = "^0.4"
+tracing = "^0.1"
+
+[build-dependencies]
+tonic-build = "0.10.1"
+prost-build = "0.12.1"
diff --git a/router/client/build.rs b/router/client/build.rs
new file mode 100644
index 00000000..497be545
--- /dev/null
+++ b/router/client/build.rs
@@ -0,0 +1,19 @@
+use std::fs;
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    println!("cargo:rerun-if-changed=../../proto/generate.proto");
+    fs::create_dir("src/pb").unwrap_or(());
+
+    let mut config = prost_build::Config::new();
+    config.protoc_arg("--experimental_allow_proto3_optional");
+
+    tonic_build::configure()
+        .build_client(true)
+        .build_server(false)
+        .out_dir("src/pb")
+        .include_file("mod.rs")
+        .compile_with_config(config,
&["../../proto/generate.proto"], &["../../proto"]) + .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + + Ok(()) +} diff --git a/router/client/src/client.rs b/router/client/src/client.rs index bb03ab83..222e9daf 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -147,7 +147,7 @@ impl Client { }), prefill_logprobs: true, top_n_tokens: 20, - lora_id: None + lora_id: None, }); n_tokens += max_input_length; diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs new file mode 100644 index 00000000..6782d9ff --- /dev/null +++ b/router/client/src/lib.rs @@ -0,0 +1,46 @@ +//! Text Generation gRPC client library + +mod client; +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; +mod sharded_client; + +pub use client::Client; +pub use pb::generate::v2::HealthResponse; +pub use pb::generate::v2::InfoResponse as ShardInfo; +pub use pb::generate::v2::{ + Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens, +}; +pub use sharded_client::ShardedClient; +use thiserror::Error; +use tonic::transport; +use tonic::Status; + +#[derive(Error, Debug, Clone)] +pub enum ClientError { + #[error("Could not connect to Text Generation server: {0}")] + Connection(String), + #[error("Server error: {0}")] + Generation(String), + #[error("Sharded results are empty")] + EmptyResults, +} + +impl From for ClientError { + fn from(err: Status) -> Self { + let err = Self::Generation(err.message().to_string()); + tracing::error!("{err}"); + err + } +} + +impl From for ClientError { + fn from(err: transport::Error) -> Self { + let err = Self::Connection(err.to_string()); + tracing::error!("{err}"); + err + } +} + +pub type Result = std::result::Result; diff --git a/router/client/src/pb/.gitignore b/router/client/src/pb/.gitignore new file mode 100644 index 00000000..6f5f3d11 --- /dev/null +++ b/router/client/src/pb/.gitignore @@ -0,0 +1 @@ +*.rs diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs new file mode 100644 index 00000000..e1e52d59 --- /dev/null +++ b/router/client/src/sharded_client.rs @@ -0,0 +1,187 @@ +use crate::client::{DecodeTimings, PrefillTimings}; +/// Multi shard Client +use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo}; +use crate::{ClientError, Result}; +use futures::future::join_all; +use tonic::transport::Uri; +use tracing::instrument; + +#[derive(Debug, Clone)] +/// Text Generation Inference gRPC multi client +pub struct ShardedClient { + clients: Vec, +} + +impl ShardedClient { + fn new(clients: Vec) -> Self { + Self { clients } + } + + /// Create a new ShardedClient from a master client. The master client will communicate with + /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. 
+ async fn from_master_client(mut master_client: Client) -> Result { + // Get all uris/unix sockets from the master client + let uris = master_client.service_discovery().await?; + let futures = uris.into_iter().map(Client::connect_uds); + let clients: Result> = join_all(futures).await.into_iter().collect(); + Ok(Self::new(clients?)) + } + + /// Returns a client connected to the given uri + pub async fn connect(uri: Uri) -> Result { + let master_client = Client::connect(uri).await?; + Self::from_master_client(master_client).await + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let master_client = Client::connect_uds(path).await?; + Self::from_master_client(master_client).await + } + + /// Get the model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.info()) + .collect(); + join_all(futures).await.pop().unwrap() + } + + /// GRPC health check + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.health()) + .collect(); + join_all(futures).await.pop().unwrap() + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.clear_cache(batch_id)) + .collect(); + join_all(futures).await.into_iter().collect() + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone()))) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup( + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_size, + )) + }) + .collect(); + // Take the minimum value + let results = join_all(futures) + .await + .into_iter() + .collect::>>>()?; + Ok(results.into_iter().flatten().min()) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.prefill(batch.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, PrefillTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > 
timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.decode(batches.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, DecodeTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } +} diff --git a/router/grpc-metadata/Cargo.toml b/router/grpc-metadata/Cargo.toml new file mode 100644 index 00000000..da163ec5 --- /dev/null +++ b/router/grpc-metadata/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "grpc-metadata" +version = "0.1.0" +edition = "2021" + +[dependencies] +opentelemetry = "^0.20" +tonic = "^0.10" +tracing = "^0.1" +tracing-opentelemetry = "^0.21" diff --git a/router/grpc-metadata/src/lib.rs b/router/grpc-metadata/src/lib.rs new file mode 100644 index 00000000..3068a61c --- /dev/null +++ b/router/grpc-metadata/src/lib.rs @@ -0,0 +1,41 @@ +//! A crate to extract and inject a OpenTelemetry context from and to a gRPC request. +//! Inspired by: https://github.com/open-telemetry/opentelemetry-rust gRPC examples + +use opentelemetry::global; +use opentelemetry::propagation::Injector; +use tracing_opentelemetry::OpenTelemetrySpanExt; + +/// Inject context in the metadata of a gRPC request. +struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap); + +impl<'a> Injector for MetadataInjector<'a> { + /// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs + fn set(&mut self, key: &str, value: String) { + if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) { + if let Ok(val) = value.parse() { + self.0.insert(key, val); + } + } + } +} + +/// Get a context from the global context and inject the span into a gRPC request's metadata. 
+fn inject(metadata: &mut tonic::metadata::MetadataMap) { + global::get_text_map_propagator(|propagator| { + propagator.inject_context( + &tracing::Span::current().context(), + &mut MetadataInjector(metadata), + ) + }) +} + +pub trait InjectTelemetryContext { + fn inject_context(self) -> Self; +} + +impl InjectTelemetryContext for tonic::Request { + fn inject_context(mut self) -> Self { + inject(self.metadata_mut()); + self + } +} diff --git a/router/src/config.rs b/router/src/config.rs new file mode 100644 index 00000000..d27b1136 --- /dev/null +++ b/router/src/config.rs @@ -0,0 +1,216 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "model_type")] +#[serde(rename_all = "snake_case")] +pub struct LlavaNext { + text_config: TextConfig, + vision_config: VisionConfig, + image_grid_pinpoints: Vec<(usize, usize)>, +} + +fn get_anyres_image_grid_shape( + height: usize, + width: usize, + grid_pinpoints: &[(usize, usize)], + patch_size: usize, +) -> (usize, usize) { + let (height, width) = select_best_resolution(height, width, grid_pinpoints); + (height / patch_size, width / patch_size) +} + +/// Selects the best resolution from a list of possible resolutions based on the original size. +/// This is done by calculating the effective and wasted resolution for each possible resolution. +/// The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution. +fn select_best_resolution( + original_height: usize, + original_width: usize, + possible_resolutions: &[(usize, usize)], +) -> (usize, usize) { + let mut best_fit = None; + let mut max_effective_resolution = 0; + let mut min_wasted_resolution = f32::NEG_INFINITY; + + for (height, width) in possible_resolutions { + let wscale = *width as f32 / original_width as f32; + let hscale = *height as f32 / original_height as f32; + // f32 partial ord. 
+ let scale = if wscale > hscale { hscale } else { wscale }; + let downscaled_width = (*width as f32 * scale) as usize; + let downscaled_height = (*height as f32 * scale) as usize; + let effective_resolution = std::cmp::min( + downscaled_width * downscaled_height, + original_width * original_height, + ); + let wasted_resolution = (width * height) - effective_resolution; + + if effective_resolution > max_effective_resolution + || (effective_resolution == max_effective_resolution + && (wasted_resolution as f32) < min_wasted_resolution) + { + max_effective_resolution = effective_resolution; + min_wasted_resolution = wasted_resolution as f32; + best_fit = Some((*height, *width)); + } + } + + best_fit.unwrap_or((original_height, original_width)) +} + +fn get_unpadded_features( + height: usize, + width: usize, + npatches: usize, + num_patch_height: usize, + num_patch_width: usize, +) -> (usize, usize) { + let current_height = npatches * num_patch_height; + let current_width = npatches * num_patch_width; + + let aspect_ratio: f64 = width as f64 / height as f64; + let current_aspect_ratio: f64 = current_width as f64 / current_height as f64; + let (current_height, current_width) = if aspect_ratio > current_aspect_ratio { + let new_height = (height * current_width) / width; + (new_height, current_width) + } else { + let new_width = (width * current_height) / height; + (current_height, new_width) + }; + + let unpadded_features = current_height * current_width; + let newline_features = current_height; + (unpadded_features, newline_features) +} + +impl LlavaNext { + pub fn get_number_of_features(&self, height: usize, width: usize) -> usize { + let image_size = self.vision_config.image_size; + let patch_size = self.vision_config.patch_size; + assert!(image_size % patch_size == 0); + let npatches = image_size / patch_size; + let (num_patch_height, num_patch_width) = + get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size); + + let (unpadded_features, newline_features) = + get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width); + // The base patch covers the entire image + let base_features = npatches.pow(2); + unpadded_features + newline_features + base_features + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct ClipVisionModel { + image_size: usize, + patch_size: usize, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct Idefics2 {} + +impl Idefics2 { + pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize { + 320 + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct PaliTextConfig { + num_image_tokens: usize, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct Paligemma { + text_config: PaliTextConfig, +} + +impl Paligemma { + pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize { + self.text_config.num_image_tokens + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "model_type")] +#[serde(rename_all = "snake_case")] +pub enum Config { + LlavaNext(LlavaNext), + ClipVisionModel(ClipVisionModel), + Mistral, + Idefics, + Idefics2(Idefics2), + Ssm, + GptBigcode, + Santacoder, + Bloom, + Mpt, + Gpt2, + GptNeox, + Phi, + #[serde(rename = "phi-msft")] + PhiMsft, + Phi3, + Llama, + Baichuan, + Paligemma(Paligemma), + Gemma, + Cohere, + Drbx, + Falcon, + Mixtral, + Starcoder2, + 
Qwen2, + Opt, + T5, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct TextConfig {} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct VisionConfig { + image_size: usize, + patch_size: usize, +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_llava_next_features() { + let config = LlavaNext { + text_config: TextConfig {}, + vision_config: VisionConfig { + image_size: 336, + patch_size: 14, + }, + image_grid_pinpoints: vec![ + (336, 672), + (672, 336), + (672, 672), + (1008, 336), + (336, 1008), + ], + }; + + let slots = config.get_number_of_features(20, 20); + assert_eq!(slots, 1176); + let slots = config.get_number_of_features(640, 640); + assert_eq!(slots, 2928); + let slots = config.get_number_of_features(480, 640); + assert_eq!(slots, 2340); + let slots = config.get_number_of_features(899, 1024); + assert_eq!(slots, 2634); + let slots = config.get_number_of_features(1024, 899); + assert_eq!(slots, 2640); + let slots = config.get_number_of_features(1067, 1600); + assert_eq!(slots, 2144); + } +} diff --git a/router/src/health.rs b/router/src/health.rs index c4a7b05b..7c3fda01 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -55,7 +55,7 @@ impl Health { ignore_eos_token: false, }), top_n_tokens: 0, - lora_id: None + lora_id: None, }; let batch = Batch { id: BATCH_ID, diff --git a/router/src/infer.rs b/router/src/infer.rs new file mode 100644 index 00000000..0410de7d --- /dev/null +++ b/router/src/infer.rs @@ -0,0 +1,1621 @@ +/// Batching and inference logic +use crate::validation::{Validation, ValidationError}; +use crate::{ + ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse, + HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, + TextMessage, Token, +}; +use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; +use futures::future::try_join_all; +use minijinja::{Environment, ErrorKind, Template}; +use nohash_hasher::IntMap; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; +use text_generation_client::{ + Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens, +}; +use thiserror::Error; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError}; +use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::StreamExt; +use tracing::{info_span, instrument, Instrument, Span}; + +/// Inference struct +#[derive(Clone)] +pub struct Infer { + /// Validation + validation: Validation, + /// Request queue + queue: Queue, + /// Shared state + shared: Arc, + /// Chat template + chat_template: Option, + /// Inference limit + limit_concurrent_requests: Arc, +} + +/// Infer shared state +struct Shared { + /// Batching background Tokio task notifier + batching_task: Notify, +} + +/// Raise a exception (custom function) used in the chat templates +fn raise_exception(err_text: String) -> Result { + Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) +} + +impl Infer { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + client: ShardedClient, + validation: Validation, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + max_concurrent_requests: usize, + 
requires_padding: bool, + window_size: Option, + speculate: u32, + generation_health: Arc, + tokenizer_config: HubTokenizerConfig, + processor_config: HubProcessorConfig, + ) -> Self { + let queue = Queue::new(requires_padding, 16, window_size, speculate); + let shared = Arc::new(Shared { + batching_task: Notify::new(), + }); + + // Spawn batching background task that contains all the inference logic + tokio::spawn(batching_task( + client, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + queue.clone(), + shared.clone(), + generation_health, + )); + + let chat_template = tokenizer_config + .chat_template + .or(processor_config.chat_template) + .and_then(|t| match t { + ChatTemplateVersions::Single(template) => Some(template), + ChatTemplateVersions::Multiple(templates) => templates + .into_iter() + .find(|t| t.name == "default") + .map(|t| t.template), + }) + .map(|t| { + // .strip() is not supported in minijinja + // .capitalize() is not supported in minijinja but we can use | capitalize + let t = t + .replace(".strip()", " | trim") + .replace(".capitalize()", " | capitalize"); + ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token) + }); + + // Inference limit with a semaphore + let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); + + Self { + validation, + queue, + shared, + chat_template, + limit_concurrent_requests: semaphore, + } + } + + /// Add a new request to the queue and return a stream of InferStreamResponse + #[instrument(skip_all)] + pub(crate) async fn generate_stream( + &self, + request: GenerateRequest, + ) -> Result { + // Limit concurrent requests by acquiring a permit from the semaphore + let permit = self + .clone() + .limit_concurrent_requests + .try_acquire_owned() + .map_err(|err| { + metrics::increment_counter!("tgi_request_failure", "err" => "overloaded"); + tracing::error!("{err}"); + err + })?; + + // Validate request + let valid_request = self.validation.validate(request).await.map_err(|err| { + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + err + })?; + + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = mpsc::unbounded_channel(); + let input_length = valid_request.input_length; + + // Append the request to the queue + self.queue.append(Entry { + request: valid_request, + response_tx, + span: Span::current(), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + }); + + // Notify the background task that we have a new entry in the queue that needs + // to be batched + self.shared.batching_task.notify_one(); + + // Return stream + Ok(( + permit, + input_length, + UnboundedReceiverStream::new(response_rx), + )) + } + + /// Tokenizer the input + #[instrument(skip_all)] + pub(crate) async fn tokenize( + &self, + request: GenerateRequest, + ) -> Result, InferError> { + // Tokenize request + let inputs = request.inputs; + let truncate = request.parameters.truncate; + let encoding = self + .validation + .tokenize(inputs, truncate) + .await + .map_err(|err| { + tracing::error!("Tokenization {err}"); + err + })?; + + // Return Encoding + Ok(encoding.map(|(encoding, _)| encoding)) + } + + /// Apply the chat template to the chat request + #[instrument(skip_all)] + pub(crate) fn apply_chat_template( + &self, + messages: Vec, + grammar_with_prompt: Option<(GrammarType, String)>, + ) -> Result { + self.chat_template + .as_ref() + .ok_or_else(|| 
InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? + .apply(messages, grammar_with_prompt) + .map_err(|e| { + metrics::increment_counter!("tgi_request_failure", "err" => "template"); + tracing::error!("{e}"); + e + }) + } + + /// Add a new request to the queue and return a InferResponse + #[instrument(skip_all)] + pub(crate) async fn generate( + &self, + request: GenerateRequest, + ) -> Result { + let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); + + // Create stream and keep semaphore permit as long as generate lives + let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; + + // Return values + let mut result_prefill = Vec::new(); + let mut result_tokens = Vec::new(); + let mut result_top_tokens = Vec::new(); + let mut result_generated_text = None; + let mut result_start = None; + let mut result_queued = None; + + // Iterate on stream + while let Some(response) = stream.next().await { + match response? { + // Add prefill tokens + InferStreamResponse::Prefill(tokens) => { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster + result_prefill = tokens + .ids + .into_iter() + .zip(tokens.logprobs.into_iter()) + .zip(tokens.texts.into_iter()) + .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) + .collect(); + } + // Push last token + InferStreamResponse::Intermediate { token, top_tokens } => { + result_tokens.push(token); + result_top_tokens.push(top_tokens); + } + // Final message + // Set return values + InferStreamResponse::End { + token, + generated_text, + start, + queued, + top_tokens, + } => { + result_tokens.push(token); + result_top_tokens.push(top_tokens); + result_generated_text = Some(generated_text); + result_start = Some(start); + result_queued = Some(queued) + } + } + } + + // Check that we received a `InferStreamResponse::End` message + if let (Some(generated_text), Some(queued), Some(start)) = + (result_generated_text, result_queued, result_start) + { + Ok(InferResponse { + prefill: result_prefill, + _input_length, + tokens: result_tokens, + generated_text, + queued, + start, + top_tokens: if use_top_tokens { + result_top_tokens + } else { + Vec::new() + }, + }) + } else { + let err = InferError::IncompleteGeneration; + metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + tracing::error!("{err}"); + Err(err) + } + } + /// Add best_of new requests to the queue and return a InferResponse of the sequence with + /// the highest log probability per token + #[instrument(skip(self, request))] + pub(crate) async fn generate_best_of( + &self, + request: GenerateRequest, + best_of: usize, + ) -> Result<(InferResponse, Vec), InferError> { + // validate best_of parameter separately + let best_of = self.validation.validate_best_of(best_of)?; + + // create multiple generate requests + let mut infer_responses: Vec = + try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?; + + // get the sequence with the highest log probability per token + let mut max_index = 0; + let mut max_logprob: f32 = f32::MIN; + + for (i, response) in infer_responses.iter().enumerate() { + // mean logprobs of the generated tokens + let sequence_logprob = response + .tokens + .iter() + .map(|token| token.logprob) + .sum::() + / response.tokens.len() as f32; + + // set best sequence + if sequence_logprob > max_logprob { + max_index = i; + max_logprob = sequence_logprob; + } + } + let best_response = infer_responses.remove(max_index); + 
Ok((best_response, infer_responses)) + } +} + +#[derive(Clone)] +struct ChatTemplate { + template: Template<'static, 'static>, + bos_token: Option, + eos_token: Option, + use_default_tool_template: bool, +} + +impl ChatTemplate { + fn new(template: String, bos_token: Option, eos_token: Option) -> Self { + let mut env = Box::new(Environment::new()); + let template_str = template.into_boxed_str(); + env.add_function("raise_exception", raise_exception); + + // check if contains the tools variable within the template + let use_default_tool_template = + !template_str.as_ref().replace(' ', "").contains("{{tools}}"); + // leaking env and template_str as read-only, static resources for performance. + let template = Box::leak(env) + .template_from_str(Box::leak(template_str)) + .unwrap(); + + Self { + template, + bos_token, + eos_token, + use_default_tool_template, + } + } + + fn apply( + &self, + mut messages: Vec, + grammar_with_prompt: Option<(GrammarType, String)>, + ) -> Result { + if self.use_default_tool_template { + if let Some(last_message) = messages.last_mut() { + if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { + last_message.content.push(MessageChunk::Text(Text { + text: format!("\n---\n{}\n{}", tool_prompt, tools), + })); + } + } + } + + let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); + + self.template + .render(ChatTemplateInputs { + messages, + bos_token: self.bos_token.as_deref(), + eos_token: self.eos_token.as_deref(), + add_generation_prompt: true, + tools: None, + tools_prompt: None, + }) + .map_err(InferError::TemplateError) + } +} + +pub struct ToolGrammar {} + +impl ToolGrammar { + pub fn apply( + tools: Option>, + tool_choice: Option, + ) -> Result, InferError> { + if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { + // let tool_prompt = tool_prompt.unwrap_or_default(); + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![req_tools + .iter() + .find(|tool| tool.function.name == *name) + .unwrap_or_else(|| panic!("Tool with name {} not found", name)) + .clone()] + } + ToolType::OneOf => req_tools.to_owned(), + }; + + // adds the error notification function for LLM feedback if required + let mut text_response_properties = Map::new(); + text_response_properties.insert( + "error".to_string(), + serde_json::json!({ + "type": "string", + "description": "The error or issue to notify" + }), + ); + text_response_properties.insert( + "_name".to_string(), + serde_json::json!({ + "type": "string", + "const": "notify_error" + }), + ); + + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); + + // Clone the existing parameters, which are expected to be a JSON object + let mut params = if let Value::Object(params) = &func.arguments { + params.clone() + } else { + Map::new() + }; + + // Insert the function's description at the top level, outside of properties + params.insert( + "description".to_string(), + Value::String(func.description.clone().unwrap_or_default()), + ); + + // Ensure 'properties' exists and is an object + let properties = params + .entry("properties".to_string()) + .or_insert_with(|| json!({})) + .as_object_mut() + .unwrap(); + + // Insert the constant for the function name inside 'properties' + properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": func.name.clone(), + // "description": "The name of the function" + }), + ); + + // Check if 'required' exists, and it is an array. 
If not, create an empty array. + let required = params + .entry("required".to_string()) + .or_insert_with(|| json!([])) + .as_array_mut() + .unwrap(); + + // Add 'name' to the 'required' array if it is not already present + if !required.iter().any(|r| r == "_name") { + required.push(json!("_name")); + } + + (func.name, Value::Object(params)) + }) + .chain([( + "notify_error".to_string(), + serde_json::json!({ + "properties": text_response_properties, + "required": ["error", "_name"], + "type": "object" + }), + )]) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .chain(std::iter::once(FunctionRef { + ref_path: "#/$functions/notify_error".to_string(), + })) + .collect(), + }, + }; + + return Ok(Some(tools)); + } + // Err(InferError::ToolError("No tools provided".to_string())) + Ok(None) + } +} + +/// Batching logic +/// Will be launched in a background Tokio task +/// +/// Batches requests and sends them to the inference server +#[allow(clippy::too_many_arguments)] +async fn batching_task( + mut client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + queue: Queue, + shared: Arc, + generation_health: Arc, +) { + // Infinite loop + loop { + // Wait for a notification from the Infer struct + shared.batching_task.notified().await; + + // Get the next batch from the queue + // This batch might be smaller than the maximum batch size if there are not enough requests + // waiting in the queue + while let Some((mut entries, batch, span)) = queue + .next_batch( + None, + max_batch_size, + max_batch_prefill_tokens, + max_batch_total_tokens, + ) + .await + { + let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) + .instrument(span) + .await; + let mut waiting_tokens = 1; + + // We loop until we do not receive any cached batch from the inference server (== until + // all requests have met their stopping criteria) + while let Some(batch) = cached_batch { + // Get current batch info + let batch_size = batch.size; + let batch_max_tokens = batch.max_tokens; + let mut batches = vec![batch]; + metrics::gauge!("tgi_batch_current_size", batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); + + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; + + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); + + // Try to get a new batch + if let Some((mut new_entries, new_batch, span)) = queue + .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) + .await + { + // Tracking metrics + if min_size.is_some() { + metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); + } else { + metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); + } + + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = 
info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); + + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = + prefill(&mut client, new_batch, &mut new_entries, &generation_health) + .instrument(span) + .await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + entries.extend(new_entries); + batches.push(new_cached_batch); + } + } + + // Create span for this batch to add context to inference calls + let next_batch_size = entries.len(); + let next_batch_span = + info_span!(parent: None, "batch", batch_size = next_batch_size); + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + }); + + cached_batch = decode(&mut client, batches, &mut entries, &generation_health) + .instrument(next_batch_span) + .await; + waiting_tokens += 1; + } + metrics::gauge!("tgi_batch_current_size", 0.0); + metrics::gauge!("tgi_batch_current_max_tokens", 0.0); + } + } +} + +#[instrument(skip_all)] +async fn prefill( + client: &mut ShardedClient, + batch: Batch, + entries: &mut IntMap, + generation_health: &Arc, +) -> Option { + let start_time = Instant::now(); + let batch_id = batch.id; + metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); + + match client.prefill(batch).await { + Ok((generations, next_batch, timings)) => { + // Update health + generation_health.store(true, Ordering::SeqCst); + + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); + metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); + metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); + metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); + metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + // Update health + generation_health.store(false, Ordering::SeqCst); + let _ = client.clear_cache(Some(batch_id)).await; + send_errors(err, entries); + metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); + None + } + } +} + +#[instrument(skip_all)] +async fn decode( + client: &mut ShardedClient, + batches: Vec, + entries: &mut IntMap, + generation_health: &Arc, +) -> Option { + let start_time = Instant::now(); + let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); + metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); + + match client.decode(batches).await { + Ok((generations, next_batch, timings)) => { + // Update health + 
generation_health.store(true, Ordering::SeqCst); + + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + if let Some(concat_duration) = timings.concat { + metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); + } + metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); + metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + generation_health.store(false, Ordering::SeqCst); + for id in batch_ids { + let _ = client.clear_cache(Some(id)).await; + } + send_errors(err, entries); + metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); + None + } + } +} + +/// Filter a `batch` and remove all requests not present in `entries` +#[instrument(skip_all)] +async fn filter_batch( + client: &mut ShardedClient, + next_batch: Option, + entries: &IntMap, +) -> Option { + let mut batch = next_batch?; + + // No need to filter + if batch.size as usize == entries.len() { + return Some(batch); + } + + let id = batch.id; + + // Retain only requests that are still in entries + batch.request_ids.retain(|id| entries.contains_key(id)); + + if batch.request_ids.is_empty() { + // All requests have been filtered out + // Next batch is now empty + // Clear it from the Python shards cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.clear_cache(Some(id)).await.unwrap(); + None + } else { + // Filter Python shard cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.filter_batch(id, batch.request_ids).await.unwrap() + } +} + +/// Send one or multiple `InferStreamResponse` to Infer for all `entries` +/// and filter entries +#[instrument(skip_all)] +fn filter_send_generations(generations: Vec, entries: &mut IntMap) { + generations.into_iter().for_each(|generation| { + let id = generation.request_id; + // Get entry + // We can `expect` here as the request id should always be in the entries + let entry = entries + .get(&id) + .expect("ID not found in entries. This is a bug."); + + // Create and enter a span to link this function back to the entry + let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); + // Send generation responses back to the infer task + // If the receive an error from the Flume channel, it means that the client dropped the + // request and we need to stop generating hence why we unwrap_or(true) + let stopped = send_responses(generation, entry).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + err + }).unwrap_or(true); + if stopped { + entries.remove(&id).expect("ID not found in entries. 
This is a bug."); + } + }); +} + +/// Send responses through the `entry` response channel +fn send_responses( + generation: Generation, + entry: &Entry, +) -> Result>>> { + // Return directly if the channel is disconnected + if entry.response_tx.is_closed() { + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + return Ok(true); + } + + let mut stopped = false; + + if let Some(prefill_tokens) = generation.prefill_tokens { + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; + } + + // Create last Token + let tokens_ = generation.tokens.expect("Non empty tokens in generation"); + let n = tokens_.ids.len(); + metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); + let mut iterator = tokens_ + .ids + .into_iter() + .zip(tokens_.logprobs) + .zip(tokens_.texts) + .zip(tokens_.is_special) + .enumerate() + .peekable(); + while let Some((i, (((id, logprob), text), special))) = iterator.next() { + let token = Token { + id, + text, + logprob, + special, + }; + let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { + top_tokens_ + .ids + .iter() + .zip(top_tokens_.logprobs.iter()) + .zip(top_tokens_.texts.iter()) + .zip(top_tokens_.is_special.iter()) + .map(|(((&id, &logprob), text), &special)| Token { + id, + text: text.to_string(), + logprob, + special, + }) + .collect() + } else { + vec![] + }; + match (&generation.generated_text, iterator.peek()) { + (Some(generated_text), None) => { + // Generation has ended + stopped = true; + // Send message + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + top_tokens, + generated_text: generated_text.clone(), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; + } + _ => { + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + } + } + } + + Ok(stopped) +} + +/// Send errors to Infer for all `entries` +#[instrument(skip_all)] +fn send_errors(error: ClientError, entries: &mut IntMap) { + entries.drain().for_each(|(_, entry)| { + // Create and enter a span to link this function back to the entry + let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); + let err = InferError::GenerationError(error.to_string()); + metrics::increment_counter!("tgi_request_failure", "err" => "generation"); + tracing::error!("{err}"); + + // unwrap_or is valid here as we don't care if the receiver is gone. + entry + .response_tx + .send(Err(err)) + .unwrap_or(()); + }); +} + +#[derive(Debug)] +pub(crate) enum InferStreamResponse { + // Optional first message + Prefill(Tokens), + // Intermediate messages + Intermediate { + token: Token, + top_tokens: Vec, + }, + // Last message + End { + token: Token, + top_tokens: Vec, + generated_text: GeneratedText, + start: Instant, + queued: Instant, + }, +} + +#[derive(Debug)] +pub(crate) struct InferResponse { + /// input_length is the input as perceived by the rust tokenizer in the + /// validation pathway. It is redundant with prefill.len() but prefill + /// has data only if the user asked for it. This will always be filled. 
+ pub(crate) _input_length: u32, + pub(crate) prefill: Vec, + pub(crate) tokens: Vec, + pub(crate) generated_text: GeneratedText, + pub(crate) queued: Instant, + pub(crate) start: Instant, + pub(crate) top_tokens: Vec>, +} + +#[derive(Debug, Error)] +pub enum InferError { + #[error("Request failed during generation: {0}")] + GenerationError(String), + #[error("Model is overloaded")] + Overloaded(#[from] TryAcquireError), + #[error("Input validation error: {0}")] + ValidationError(#[from] ValidationError), + #[error("Incomplete generation")] + IncompleteGeneration, + #[error("Template error: {0}")] + TemplateError(#[from] minijinja::Error), + #[error("Tool error: {0}")] + ToolError(String), +} + +impl InferError { + pub(crate) fn error_type(&self) -> &str { + match self { + InferError::GenerationError(_) => "generation", + InferError::Overloaded(_) => "overloaded", + InferError::ValidationError(_) => "validation", + InferError::IncompleteGeneration => "incomplete_generation", + InferError::TemplateError(_) => "template_error", + InferError::ToolError(_) => "tool_error", + } + } +} + +// tests +#[cfg(test)] +mod tests { + use crate::infer::raise_exception; + use crate::{ChatTemplateInputs, TextMessage}; + use minijinja::Environment; + + #[test] + fn test_chat_template() { + let env = Environment::new(); + + let source = r#" + {% for message in messages %} + {% if message['role'] == 'system' %} + {% if message['content']%} + {{'### System:\n' + message['content']+'\n\n'}} + {% endif %} + {% elif message['role'] == 'user' %} + {{'### User:\n' + message['content']+'\n\n'}} + {% elif message['role'] == 'assistant' %} + {{'### Assistant:\n' + message['content']}} + {% endif %} + {% if loop.last and add_generation_prompt %} + {{ '### Assistant:\n' }} + {% endif %} + {% endfor %}"#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + TextMessage { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + ..Default::default() + }; + + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + + assert_eq!( + result, + "### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n" + ); + } + + #[test] + fn test_chat_template_invalid_with_raise() { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + + let source = r#" + {{ bos_token }} + {% for message in messages %} + {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {% endif %} + {% if message['role'] == 'user' %} + {{ '[INST] ' + message['content'] + ' [/INST]' }} + {% elif message['role'] == 'assistant' %} + {{ message['content'] + eos_token}} + {% else %} + {{ raise_exception('Only user and assistant roles are supported!') }} + {% endif %} + {% endfor %}"#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + 
.collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + TextMessage { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "Hi again!".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + ..Default::default() + }; + + let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); + + match result { + Ok(_) => panic!("Should have failed"), + Err(e) => { + assert_eq!( + e.detail().unwrap(), + "Conversation roles must alternate user/assistant/user/assistant/..." + ); + } + } + } + + #[test] + fn test_chat_template_valid_with_raise() { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + + let source = r#" + {{ bos_token }} + {% for message in messages %} + {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {% endif %} + {% if message['role'] == 'user' %} + {{ '[INST] ' + message['content'] + ' [/INST]' }} + {% elif message['role'] == 'assistant' %} + {{ message['content'] + eos_token}} + {% else %} + {{ raise_exception('Only user and assistant roles are supported!') }} + {% endif %} + {% endfor %}"#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + TextMessage { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + ..Default::default() + }; + + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? 
[/INST]magic![EOS]"); + } + + #[test] + fn test_chat_template_valid_with_add_generation_prompt() { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + + let source = r#" + {% for message in messages %} + {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}} + {% endfor %} + {% if add_generation_prompt %} + {{ '<|im_start|>assistant\n' }} + {% endif %}"#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + TextMessage { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + ..Default::default() + }; + + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n"); + } + + struct ChatTemplateTestItem { + name: &'static str, + chat_template: &'static str, + input: ChatTemplateInputs<'static>, + target: &'static str, + } + + #[test] + fn test_many_chat_templates() { + let example_chat = vec![ + TextMessage { + role: "user".to_string(), + content: "Hello, how are you?".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "I'm doing great. How can I help you today?".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "I'd like to show off how chat templating works!".to_string(), + }, + ]; + + let example_chat_with_system = [TextMessage { + role: "system".to_string(), + content: "You are a friendly chatbot who always responds in the style of a pirate" + .to_string(), + }] + .iter() + .chain(&example_chat) + .cloned() + .collect::>(); + + let test_default_templates = vec![ + ChatTemplateTestItem { + name: "_base", + chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", + }, + ChatTemplateTestItem { + name: "blenderbot", + chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: " Hello, how are you? I'm doing great. How can I help you today? 
I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "blenderbot_small", + chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "bloom", + chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "gpt_neox", + chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some("<|endoftext|>"), + ..Default::default() + }, + target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", + }, + ChatTemplateTestItem { + name: "gpt2", + chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some("<|endoftext|>"), + ..Default::default() + }, + target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", + }, + ChatTemplateTestItem { + name: "llama", + // NOTE: the `.strip()` has been replaced with `| trim` in the following template + chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content | trim + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat_with_system.clone(), + add_generation_prompt: true, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. 
How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", + }, + ChatTemplateTestItem { + name: "whisper", + chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: true, + bos_token: Some(""), + eos_token: Some("<|endoftext|>"), + ..Default::default() + }, + target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", + }, + ]; + + #[allow(unused_variables)] // name is unused + for ChatTemplateTestItem { + name, + chat_template, + input, + target, + } in test_default_templates + { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + let tmpl = env.template_from_str(chat_template); + let result = tmpl.unwrap().render(input).unwrap(); + assert_eq!(result, target); + } + + let test_custom_templates = vec![ + ChatTemplateTestItem { + name: "HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=false)", + chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat_with_system.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHello, how are you?<|assistant|>\nI'm doing great. 
How can I help you today?<|user|>\nI'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=true)", + chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + input: ChatTemplateInputs { + messages: vec![ + TextMessage{ + role: "system".to_string(), + content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(), + }, + TextMessage{ + role: "user".to_string(), + content: "How many helicopters can a human eat in one sitting?".to_string(), + }, + ], + add_generation_prompt: true, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHow many helicopters can a human eat in one sitting?<|assistant|>", + }, + ChatTemplateTestItem { + name: "HuggingFaceH4/zephyr-7b-gemma-v0.1", + chat_template: "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", + }, + ChatTemplateTestItem { + name: "mistralai/Mistral-7B-Instruct-v0.1", + chat_template: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! 
[/INST]", + }, + ChatTemplateTestItem { + name: "mistralai/Mixtral-8x7B-Instruct-v0.1", + chat_template: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]", + }, + ChatTemplateTestItem { + name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b", + chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", + }, + ChatTemplateTestItem { + name: "openchat/openchat-3.5-0106", + // `.title()` has been replaced with `| upper` in the following template + chat_template: "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + (message['role'] | title) + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>", + }, + ChatTemplateTestItem { + name: "upstage/SOLAR-10.7B-Instruct-v1.0", + chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "Hello, how are you?I'm doing great. 
How can I help you today?I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "codellama/CodeLlama-70b-Instruct-hf", + // NOTE: `.strip()` has been replaced with `| trim` in the following template + chat_template: "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\\n\\n ' + message['content'] | trim %}{{ content + ' ' }}{% endfor %}{{'Source: assistant\\nDestination: user\\n\\n '}}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "Source: user\n\n Hello, how are you? Source: assistant\n\n I'm doing great. How can I help you today? Source: user\n\n I'd like to show off how chat templating works! Source: assistant\nDestination: user\n\n ", + }, + ChatTemplateTestItem { + name: "Deci/DeciLM-7B-instruct", + chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### User:\\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ '### System:\\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '### Assistant:\\n' + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Assistant:' }}\n{% endif %}\n{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "### User:\nHello, how are you?### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "Qwen/Qwen1.5-72B-Chat", + chat_template: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. 
How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "deepseek-ai/deepseek-llm-7b-chat", + chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\\n\\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some("<|begin▁of▁sentence|>"), + eos_token: Some("<|end▁of▁sentence|>"), + ..Default::default() + }, + target: "<|begin▁of▁sentence|>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<|end▁of▁sentence|>User: I'd like to show off how chat templating works!\n\n", + }, + ChatTemplateTestItem { + name: "h2oai/h2o-danube-1.8b-chat", + chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|prompt|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|answer|>' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|answer|>' }}{% endif %}{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|prompt|>Hello, how are you?<|answer|>I'm doing great. How can I help you today?<|prompt|>I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "internlm/internlm2-chat-7b", + chat_template: "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", + }, + ChatTemplateTestItem { + name: "TheBloke/deepseek-coder-33B-instruct-AWQ", + chat_template: "{%- set found_item = false -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set found_item = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not found_item -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response:\\n'}}\n", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some("<|begin▁of▁sentence|>"), + eos_token: Some("<|EOT|>"), + ..Default::default() + }, + target: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n", + }, + ChatTemplateTestItem { + name: "ericzzz/falcon-rw-1b-chat", + // `.strip()` has been replaced with `| trim` in the following template + chat_template: "{% for message in messages %}{% if loop.index > 1 and loop.previtem['role'] != 'assistant' %}{{ ' ' }}{% endif %}{% if message['role'] == 'system' %}{{ '[SYS] ' + message['content'] | trim }}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim }}{% elif message['role'] == 'assistant' %}{{ '[RESP] ' + message['content'] + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' [RESP] ' }}{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some("<|endoftext|>"), + eos_token: Some("<|endoftext|>"), + ..Default::default() + }, + target: "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "abacusai/Smaug-34B-v0.1", + chat_template: "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <>\\n' + messages[idx]['content'] + '\\n<>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", + }, + ChatTemplateTestItem { + name: "maywell/Synatra-Mixtral-8x7B", + chat_template: "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:Hello, how are you?### Response:I'm doing great. How can I help you today?### Instruction:I'd like to show off how chat templating works!", + }, + ChatTemplateTestItem { + name: "deepseek-ai/deepseek-coder-33b-instruct", + chat_template: "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}", + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some("<|begin▁of▁sentence|>"), + eos_token: Some(""), + ..Default::default() + }, + target: "<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. 
How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n", + }, + // NOT INCLUDED + // - meetkai/functionary-medium-v2.2 + // - fireworks-ai/firefunction-v1 + // https://github + ChatTemplateTestItem { + name: "maywell/PiVoT-MoE", + chat_template: "{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content']|trim }}{% elif message['role'] == 'user' %}### Instruction: {{ message['content']|trim }}{% elif message['role'] == 'assistant' %}### Response: {{ message['content']|trim }}{% elif message['role'] == 'user_context' %}### Input: {{ message['content']|trim }}{% endif %}{% if not loop.last %}\n{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}### Response:{% endif %}", + input: ChatTemplateInputs { + messages: example_chat_with_system.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + ..Default::default() + }, + target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!", + }, + ]; + + #[allow(unused_variables)] // name is unused + for ChatTemplateTestItem { + name, + chat_template, + input, + target, + } in test_custom_templates + { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + // trim all the whitespace + let chat_template = chat_template + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&chat_template); + let result = tmpl.unwrap().render(input).unwrap(); + assert_eq!(result, target); + } + } +} diff --git a/router/src/lib.rs b/router/src/lib.rs index 2ed1d423..8fb0e816 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -41,6 +41,12 @@ pub(crate) struct VertexResponse { pub predictions: Vec, } +#[derive(Deserialize, ToSchema)] +pub(crate) struct LoRAAdapterControlRequest { + pub lora_id: String, + pub hf_api_token: Option, +} + /// Hub type #[derive(Clone, Debug, Deserialize)] pub struct HubModelInfo { @@ -80,6 +86,20 @@ impl HubTokenizerConfig { } } +#[derive(Debug, Clone, Deserialize, Default)] +pub struct HubProcessorConfig { + pub chat_template: Option, + pub image_seq_len: usize, + pub processor_class: Option, +} + +impl HubProcessorConfig { + pub fn from_file>(filename: P) -> Option { + let content = std::fs::read_to_string(filename).ok()?; + serde_json::from_str(&content).ok() + } +} + #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] #[serde(tag = "type", content = "value")] pub(crate) enum GrammarType { @@ -328,7 +348,7 @@ fn default_parameters() -> GenerateParameters { seed: None, top_n_tokens: None, grammar: None, - lora_id: None + lora_id: None, } } @@ -413,7 +433,7 @@ pub struct CompletionRequest { #[serde(default)] #[schema(nullable = true, example = "null")] pub stop: Option>, - + /// LoRA id #[serde(default)] #[schema(nullable = true, default = "empty", example = "empty")] diff --git a/router/src/main.rs b/router/src/main.rs new file mode 100644 index 00000000..b526367c --- /dev/null +++ b/router/src/main.rs @@ -0,0 +1,572 @@ +use axum::http::HeaderValue; +use clap::Parser; +use hf_hub::api::tokio::{Api, ApiBuilder, 
ApiRepo}; +use hf_hub::{Cache, Repo, RepoType}; +use opentelemetry::sdk::propagation::TraceContextPropagator; +use opentelemetry::sdk::trace; +use opentelemetry::sdk::trace::Sampler; +use opentelemetry::sdk::Resource; +use opentelemetry::{global, KeyValue}; +use opentelemetry_otlp::WithExportConfig; +use std::fs::File; +use std::io::BufReader; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::path::{Path, PathBuf}; +use text_generation_client::{ClientError, ShardedClient}; +use text_generation_router::config::Config; +use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig}; +use thiserror::Error; +use tokenizers::Tokenizer; +use tower_http::cors::AllowOrigin; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "20", long, env)] + max_waiting_tokens: usize, + #[clap(long, env)] + max_batch_size: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(default_value = "/tmp/text-generation-server-0", long, env)] + master_shard_uds_path: String, + #[clap(default_value = "bigscience/bloom", long, env)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env)] + ngrok: bool, + #[clap(long, env)] + ngrok_authtoken: Option, + #[clap(long, env)] + ngrok_edge: Option, + #[clap(long, env, default_value_t = false)] + messages_api_enabled: bool, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, +} + +#[tokio::main] +async fn main() -> Result<(), RouterError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + hostname, + port, + master_shard_uds_path, + tokenizer_name, + tokenizer_config_path, + revision, + validation_workers, + json_output, + otlp_endpoint, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + } = args; + + // Launch Tokio runtime + init_logging(otlp_endpoint, json_output); + + // Validate args + if max_input_tokens >= 
max_total_tokens { + return Err(RouterError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(RouterError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + // CORS allowed origins + // map to go inside the option and then map to parse from String to HeaderValue + // Finally, convert to AllowOrigin + let cors_allow_origin: Option = cors_allow_origin.map(|cors_allow_origin| { + AllowOrigin::list( + cors_allow_origin + .iter() + .map(|origin| origin.parse::().unwrap()), + ) + }); + + // Parse Huggingface hub token + let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok(); + + // Tokenizer instance + // This will only be used to validate payloads + let local_path = Path::new(&tokenizer_name); + + // Shared API builder initialization + let api_builder = || { + let mut builder = ApiBuilder::new() + .with_progress(false) + .with_token(authorization_token); + + if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") { + builder = builder.with_cache_dir(cache_dir.into()); + } + + builder + }; + + // Decide if we need to use the API based on the revision and local path + let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); + + // Initialize API if needed + #[derive(Clone)] + enum Type { + Api(Api), + Cache(Cache), + None, + } + let api = if use_api { + if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { + let cache = Cache::default(); + tracing::warn!("Offline mode active using cache defaults"); + Type::Cache(cache) + } else { + tracing::info!("Using the Hugging Face API"); + match api_builder().build() { + Ok(api) => Type::Api(api), + Err(_) => { + tracing::warn!("Unable to build the Hugging Face API"); + Type::None + } + } + } + } else { + Type::None + }; + + // Load tokenizer and model info + let ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + processor_config_filename, + model_info, + ) = match api { + Type::None => ( + Some(local_path.join("tokenizer.json")), + Some(local_path.join("config.json")), + Some(local_path.join("tokenizer_config.json")), + Some(local_path.join("processor_config.json")), + None, + ), + Type::Api(api) => { + let api_repo = api.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + + let tokenizer_filename = match api_repo.get("tokenizer.json").await { + Ok(tokenizer_filename) => Some(tokenizer_filename), + Err(_) => get_base_tokenizer(&api, &api_repo).await, + }; + let config_filename = api_repo.get("config.json").await.ok(); + let tokenizer_config_filename = 
api_repo.get("tokenizer_config.json").await.ok(); + let processor_config_filename = api_repo.get("processor_config.json").await.ok(); + + let model_info = if let Some(model_info) = get_model_info(&api_repo).await { + Some(model_info) + } else { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + None + }; + ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + processor_config_filename, + model_info, + ) + } + Type::Cache(cache) => { + let repo = cache.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + ( + repo.get("tokenizer.json"), + repo.get("config.json"), + repo.get("tokenizer_config.json"), + repo.get("processor_config.json"), + None, + ) + } + }; + let tokenizer: Option = + tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()); + let config: Option = config_filename.and_then(|filename| { + std::fs::read_to_string(filename) + .ok() + .as_ref() + .and_then(|c| { + let config: Result = serde_json::from_str(c); + if let Err(err) = &config { + tracing::warn!("Could not parse config {err:?}"); + } + config.ok() + }) + }); + let model_info = model_info.unwrap_or_else(|| HubModelInfo { + model_id: tokenizer_name.to_string(), + sha: None, + pipeline_tag: None, + }); + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. + let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path + { + HubTokenizerConfig::from_file(filename) + } else { + tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) + }; + let tokenizer_config = tokenizer_config.unwrap_or_else(|| { + tracing::warn!("Could not find tokenizer config locally and no API specified"); + HubTokenizerConfig::default() + }); + + let processor_config = processor_config_filename + .and_then(HubProcessorConfig::from_file) + .unwrap_or_default(); + + tracing::info!("Using config {config:?}"); + if tokenizer.is_none() { + tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); + tracing::warn!("Rust input length validation and truncation is disabled"); + } + + // if pipeline-tag == text-generation we default to return_full_text = true + let compat_return_full_text = match &model_info.pipeline_tag { + None => { + tracing::warn!("no pipeline tag found for model {tokenizer_name}"); + true + } + Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", + }; + + // Instantiate sharded client from the master unix socket + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(RouterError::Connection)?; + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(RouterError::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(RouterError::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_supported_batch_total_tokens = match sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_size, + ) + .await + .map_err(RouterError::Warmup)? 
+ { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens + .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + tracing::warn!("Model does not support automatic max batch total tokens"); + max_batch_total_tokens + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." + ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}"))); + } + + max_supported_batch_total_tokens + } + }; + tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); + tracing::info!("Connected"); + + // Determine the server port based on the feature and environment variable. + let port = if cfg!(feature = "google") { + std::env::var("AIP_HTTP_PORT") + .map(|aip_http_port| aip_http_port.parse::().unwrap_or(port)) + .unwrap_or(port) + } else { + port + }; + + let addr = match hostname.parse() { + Ok(ip) => SocketAddr::new(ip, port), + Err(_) => { + tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) + } + }; + + // Run server + server::run( + model_info, + shard_info, + compat_return_full_text, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_supported_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + sharded_client, + tokenizer, + config, + validation_workers, + addr, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + tokenizer_config, + processor_config, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + ) + .await?; + Ok(()) +} + +/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: +/// - otlp_endpoint is an optional URL to an Open Telemetry collector +/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) +/// - LOG_FORMAT may be TEXT or JSON (default to TEXT) +/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms) +fn init_logging(otlp_endpoint: Option, json_output: bool) { + let mut layers = Vec::new(); + + // STDOUT/STDERR layer + let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string()); + let fmt_layer = tracing_subscriber::fmt::layer() + .with_file(true) + .with_ansi(ansi) + .with_line_number(true); + + let fmt_layer = match json_output { + true => fmt_layer.json().flatten_event(true).boxed(), + false => fmt_layer.boxed(), + }; + layers.push(fmt_layer); + + // OpenTelemetry tracing layer + if let Some(otlp_endpoint) = otlp_endpoint { + global::set_text_map_propagator(TraceContextPropagator::new()); + + let tracer = opentelemetry_otlp::new_pipeline() + .tracing() + .with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_endpoint(otlp_endpoint), + ) + .with_trace_config( + trace::config() + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + 
"text-generation-inference.router", + )])) + .with_sampler(Sampler::AlwaysOn), + ) + .install_batch(opentelemetry::runtime::Tokio); + + if let Ok(tracer) = tracer { + layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); + init_tracing_opentelemetry::init_propagator().unwrap(); + }; + } + + // Filter events with LOG_LEVEL + let varname = "LOG_LEVEL"; + let env_filter = if let Ok(log_level) = std::env::var(varname) { + // Override to avoid simple logs to be spammed with tokio level informations + let log_level = match &log_level[..] { + "warn" => "text_generation_launcher=warn,text_generation_router=warn", + "info" => "text_generation_launcher=info,text_generation_router=info", + "debug" => "text_generation_launcher=debug,text_generation_router=debug", + log_level => log_level, + }; + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .parse_lossy(log_level) + } else { + EnvFilter::new("info") + }; + + tracing_subscriber::registry() + .with(env_filter) + .with(layers) + .init(); +} + +/// get model info from the Huggingface Hub +pub async fn get_model_info(api: &ApiRepo) -> Option { + let response = api.info_request().send().await.ok()?; + + if response.status().is_success() { + let hub_model_info: HubModelInfo = + serde_json::from_str(&response.text().await.ok()?).ok()?; + if let Some(sha) = &hub_model_info.sha { + tracing::info!( + "Serving revision {sha} of model {}", + hub_model_info.model_id + ); + } + Some(hub_model_info) + } else { + None + } +} + +/// get base tokenizer +pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option { + let config_filename = api_repo.get("config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of `User`. + let config: serde_json::Value = serde_json::from_reader(reader).ok()?; + + if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { + let api_base_repo = api.repo(Repo::with_revision( + base_model_id.to_string(), + RepoType::Model, + "main".to_string(), + )); + + api_base_repo.get("tokenizer.json").await.ok() + } else { + None + } +} + +/// get tokenizer_config from the Huggingface Hub +pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option { + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(tokenizer_config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. 
+ let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader) + .map_err(|e| { + tracing::warn!("Unable to parse tokenizer config: {}", e); + e + }) + .ok()?; + + Some(tokenizer_config) +} + +#[derive(Debug, Error)] +enum RouterError { + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to clear the Python model shards cache: {0}")] + Cache(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), + #[error("Axum webserver failed: {0}")] + Axum(#[from] axum::BoxError), +} diff --git a/router/src/queue.rs b/router/src/queue.rs index 4af0585e..82334d59 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -283,7 +283,7 @@ impl State { parameters: Some(entry.request.parameters.clone()), stopping_parameters: Some(entry.request.stopping_parameters.clone()), top_n_tokens: entry.request.top_n_tokens, - lora_id: Some(entry.request.lora_id.clone()) + lora_id: Some(entry.request.lora_id.clone()), }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -390,6 +390,7 @@ mod tests { stop_sequences: vec![], }, top_n_tokens: 0, + lora_id: "empty".to_string(), }, response_tx, span: info_span!("entry"), diff --git a/router/src/server.rs b/router/src/server.rs index 1ba423d7..203c2f69 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -5,15 +5,16 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar}; use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, - PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, - Validation, + GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Infer, + Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, + TokenizeResponse, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, - CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, + CompletionRequest, DeltaToolCall, Function, LoRAAdapterControlRequest, Tool, VertexRequest, + VertexResponse, }; use crate::{FunctionDefinition, ToolCall, ToolType}; use async_stream::__private::AsyncStream; @@ -31,9 +32,17 @@ use futures::TryStreamExt; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use serde_json::Value; use std::convert::Infallible; +use std::ffi::OsString; +use std::io::{BufRead, BufReader}; use std::net::SocketAddr; +use std::os::unix::process::CommandExt; +use std::process::{Command, Stdio}; use std::sync::atomic::AtomicBool; +use std::sync::mpsc; use std::sync::Arc; +use std::thread::sleep; +use std::time::Duration; +use std::{env, thread}; use text_generation_client::{ShardInfo, ShardedClient}; use tokenizers::Tokenizer; use tokio::select; @@ -665,7 +674,7 @@ async fn completions( seed, top_n_tokens: None, grammar: None, - lora_id: None + 
lora_id: None, }, }) .collect(); @@ -1090,7 +1099,7 @@ async fn chat_completions( seed, top_n_tokens: req.top_logprobs, grammar: typed_grammar, - lora_id: None + lora_id: None, }, }; @@ -1335,7 +1344,8 @@ async fn tokenize( .iter() .zip(encoding.get_offsets()) .map(|(&id, &(start, stop))| { - let text: String = input.chars().skip(start).take(stop - start).collect(); + let text: String = + String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string(); SimpleToken { id, text, @@ -1367,6 +1377,103 @@ async fn metrics(prom_handle: Extension) -> String { prom_handle.render() } +/// Tokenize inputs +#[utoipa::path( + post, + tag = "Text Generation Inference", + path = "/lora_adapter_control", + request_body = LoRAAdapterControlRequest, + responses( + (status = 200, description = "LoRA Adapter Control Response", body = LoRAAdapterControlResponse), + (status = 400, description = "TGI server not found", body = ErrorResponse, + example = json ! ({"error": "TGI server not found"})), + (status = 404, description = "No model path found", body = ErrorResponse, + example = json ! ({"error": "No model path found"})), + ) +)] +#[instrument(skip_all)] +async fn download_lora_adapter( + Extension(_infer): Extension, + Json(req): Json, +) -> Result<(), (StatusCode, Json)> { + let download_args = vec![ + "download-lora-adapters".to_string(), + req.lora_id.to_string(), + ]; + + // Copy current process env + let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); + + // Enable hf transfer for insane download speeds + let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); + envs.push(( + "HF_HUB_ENABLE_HF_TRANSFER".into(), + enable_hf_transfer.into(), + )); + + // Parse Inference API token + if let Some(token) = req.hf_api_token { + envs.push(("HUGGING_FACE_HUB_TOKEN".into(), token.into())); + } + + // Start process + tracing::info!("Starting LoRA adapter download process."); + let mut download_process = match Command::new("text-generation-server") + .args(download_args) + .env_clear() + .envs(envs) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .process_group(0) + .spawn() + { + Ok(p) => p, + Err(_) => { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: "TGI server not found.".to_string(), + error_type: "TGI server not found".to_string(), + }), + )) + } + }; + + let download_stderr = BufReader::new(download_process.stderr.take().unwrap()); + + // We read stderr in another thread as it seems that lines() can block in some cases + let (err_sender, err_receiver) = mpsc::channel(); + thread::spawn(move || { + for line in download_stderr.lines().map_while(Result::ok) { + err_sender.send(line).unwrap_or(()); + } + }); + + loop { + if let Some(status) = download_process.try_wait().unwrap() { + if status.success() { + tracing::info!("Successfully downloaded weights."); + break; + } + + let mut err = String::new(); + while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { + err = err + "\n" + &line; + } + + return Err(( + StatusCode::NOT_FOUND, + Json(ErrorResponse { + error: "No model path found or authorization failed.".to_string(), + error_type: "download error".to_string(), + }), + )); + } + sleep(Duration::from_millis(100)); + } + Ok(()) +} + #[derive(Clone, Debug)] pub(crate) struct ComputeType(String); @@ -1394,9 +1501,10 @@ pub async fn run( addr: SocketAddr, allow_origin: Option, ngrok: bool, - ngrok_authtoken: Option, - ngrok_edge: Option, + _ngrok_authtoken: Option, + _ngrok_edge: Option, tokenizer_config: 
HubTokenizerConfig, + processor_config: HubProcessorConfig, messages_api_enabled: bool, grammar_support: bool, max_client_batch_size: usize, @@ -1497,6 +1605,7 @@ pub async fn run( shard_info.speculate, generation_health, tokenizer_config, + processor_config, ); // Duration buckets @@ -1632,11 +1741,15 @@ pub async fn run( let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string())); + let lora_control_route = + Router::new().route("/download_lora_adapter", post(download_lora_adapter)); + // Combine routes and layers let mut app = Router::new() .merge(swagger_ui) .merge(base_routes) - .merge(aws_sagemaker_route); + .merge(aws_sagemaker_route) + .merge(lora_control_route); #[cfg(feature = "google")] { @@ -1666,46 +1779,9 @@ pub async fn run( if ngrok { #[cfg(feature = "ngrok")] { - use ngrok::config::TunnelBuilder; - - let _ = addr; - - let authtoken = - ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling"); - - let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling"); - - let tunnel = ngrok::Session::builder() - .authtoken(authtoken) - .connect() - .await - .unwrap() - .labeled_tunnel() - .label("edge", edge); - - let listener = tunnel.listen().await.unwrap(); - - // Run prom metrics and health locally too - tokio::spawn( - axum::Server::bind(&addr) - .serve( - Router::new() - .route("/health", get(health)) - .route("/metrics", get(metrics)) - .layer(Extension(health_ext)) - .layer(Extension(prom_handle)) - .into_make_service(), - ) - //Wait until all requests are finished to shut down - .with_graceful_shutdown(shutdown_signal()), - ); + panic!("ngrok feature is not functional with axum=0.7 and hyper=1, waiting on https://github.com/ngrok/ngrok-rust/pull/137/files to re-enable."); // Run server - axum::Server::builder(listener) - .serve(app.into_make_service()) - //Wait until all requests are finished to shut down - .with_graceful_shutdown(shutdown_signal()) - .await?; } #[cfg(not(feature = "ngrok"))] { @@ -1718,9 +1794,9 @@ pub async fn run( } } else { // Run server - axum::Server::bind(&addr) - .serve(app.into_make_service()) - // Wait until all requests are finished to shut down + + let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); + axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) .await?; } diff --git a/router/src/validation.rs b/router/src/validation.rs index 20390457..6235f1b5 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -13,7 +13,7 @@ use thiserror::Error; use tokenizers::tokenizer::Tokenizer; // use tokenizers::TruncationDirection; use base64::{engine::general_purpose::STANDARD, Engine}; -use hf_hub::{Cache, Repo, RepoType}; +use hf_hub::{Cache, Repo}; use image::{io::Reader as ImageReader, ImageFormat}; use tokio::sync::mpsc; use tokio::sync::oneshot; @@ -387,8 +387,8 @@ impl Validation { let cache = Cache::default(); let repo = cache.repo(Repo::model(lid.clone())); match repo.get("adapter_model.bin") { - Some(_) => { lid } - None => return Err(ValidationError::LoRANotLoaded(lid)) + Some(_) => lid, + None => return Err(ValidationError::LoRANotLoaded(lid)), } } }; @@ -401,7 +401,7 @@ impl Validation { parameters, stopping_parameters, top_n_tokens, - lora_id: loraid + lora_id: loraid, }) } @@ -660,7 +660,7 @@ pub(crate) struct ValidGenerateRequest { pub parameters: NextTokenChooserParameters, pub stopping_parameters: StoppingCriteriaParameters, pub top_n_tokens: u32, - pub lora_id: String + pub lora_id: String, } 
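+// NOTE on `ValidGenerateRequest::lora_id` above: it is a plain `String` rather than
+// an `Option`; the literal value "empty" is used as the "no adapter" sentinel (see the
+// request schema defaults and the queue tests). During validation a requested adapter
+// is only accepted when its `adapter_model.bin` is already present in the local
+// Hugging Face cache; otherwise the request is rejected with `ValidationError::LoRANotLoaded`.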
#[derive(Error, Debug)] @@ -726,7 +726,7 @@ pub enum ValidationError { #[error("Could not fetch image: {0}")] FailedFetchImage(#[from] reqwest::Error), #[error("LoRA adaptor {0} not loaded")] - LoRANotLoaded(String) + LoRANotLoaded(String), } #[cfg(test)] diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 00000000..83f9a5b0 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,5 @@ +[toolchain] +# Released on: 13 June, 2024 +# https://releases.rs/docs/1.79.0/ +channel = "1.79.0" +components = ["rustfmt", "clippy"] diff --git a/sagemaker-entrypoint.sh b/sagemaker-entrypoint.sh new file mode 100755 index 00000000..9ac47010 --- /dev/null +++ b/sagemaker-entrypoint.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +if [[ -z "${HF_MODEL_ID}" ]]; then + echo "HF_MODEL_ID must be set" + exit 1 +fi +export MODEL_ID="${HF_MODEL_ID}" + +if [[ -n "${HF_MODEL_REVISION}" ]]; then + export REVISION="${HF_MODEL_REVISION}" +fi + +if [[ -n "${SM_NUM_GPUS}" ]]; then + export NUM_SHARD="${SM_NUM_GPUS}" +fi + +if [[ -n "${HF_MODEL_QUANTIZE}" ]]; then + export QUANTIZE="${HF_MODEL_QUANTIZE}" +fi + +if [[ -n "${HF_MODEL_TRUST_REMOTE_CODE}" ]]; then + export TRUST_REMOTE_CODE="${HF_MODEL_TRUST_REMOTE_CODE}" +fi + +text-generation-launcher --port 8080 diff --git a/server/.gitignore b/server/.gitignore new file mode 100644 index 00000000..576746ee --- /dev/null +++ b/server/.gitignore @@ -0,0 +1,164 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +text_generation_server/__pycache__/ +text_generation_server/pb/__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. 
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +transformers +safetensors +flash-attention/ +flash-attention-v2/ +vllm/ +llm-awq/ +eetq/ +mamba/ diff --git a/server/Makefile b/server/Makefile new file mode 100644 index 00000000..32d01709 --- /dev/null +++ b/server/Makefile @@ -0,0 +1,30 @@ +include Makefile-flash-att +include Makefile-flash-att-v2 +include Makefile-vllm +include Makefile-awq +include Makefile-eetq +include Makefile-selective-scan + +unit-tests: + pytest -s -vv -m "not private" tests + +gen-server: + # Compile protos + pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir + mkdir text_generation_server/pb || true + python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb \ + --grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/generate.proto + find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; + touch text_generation_server/pb/__init__.py + +install: gen-server + pip install pip --upgrade + pip install -r requirements_cuda.txt + pip install -e ".[bnb, accelerate, quantize, peft, outlines]" + +run-dev: + SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded + +export-requirements: + poetry export -o requirements_cuda.txt --without-hashes + poetry export -o requirements_rocm.txt --without-hashes diff --git a/server/Makefile-awq b/server/Makefile-awq new file mode 100644 index 00000000..4e074a13 --- /dev/null +++ b/server/Makefile-awq @@ -0,0 +1,15 @@ +# Fork that adds only the correct stream to this kernel in order +# to make cuda graphs work. 
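+# Typical usage: `make install-awq` clones the fork, checks out $(awq_commit) below,
+# builds the kernels and reinstalls `awq_inference_engine` into the current environment.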
+awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4 + +awq: + rm -rf llm-awq + git clone https://github.com/huggingface/llm-awq + +build-awq: awq + cd llm-awq/ && git fetch && git checkout $(awq_commit) + cd llm-awq/awq/kernels && python setup.py build + +install-awq: build-awq + pip uninstall awq_inference_engine -y || true + cd llm-awq/awq/kernels && python setup.py install diff --git a/server/Makefile-eetq b/server/Makefile-eetq new file mode 100644 index 00000000..726e47b5 --- /dev/null +++ b/server/Makefile-eetq @@ -0,0 +1,13 @@ +eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0 + +eetq: + # Clone eetq + pip install packaging + git clone https://github.com/NetEase-FuXi/EETQ.git eetq + +build-eetq: eetq + cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive + cd eetq && python setup.py build + +install-eetq: build-eetq + cd eetq && python setup.py install diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att new file mode 100644 index 00000000..ffa304aa --- /dev/null +++ b/server/Makefile-flash-att @@ -0,0 +1,16 @@ +flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec + +flash-attention: + # Clone flash attention + pip install -U packaging ninja --no-cache-dir + git clone https://github.com/HazyResearch/flash-attention.git + +build-flash-attention: flash-attention + cd flash-attention && git fetch && git checkout $(flash_att_commit) + cd flash-attention && python setup.py build + cd flash-attention/csrc/rotary && python setup.py build + cd flash-attention/csrc/layer_norm && python setup.py build + +install-flash-attention: build-flash-attention + pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true + cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 new file mode 100644 index 00000000..bbff0090 --- /dev/null +++ b/server/Makefile-flash-att-v2 @@ -0,0 +1,29 @@ +flash_att_v2_commit_cuda := v2.5.8 +flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 + + +flash-attention-v2-cuda: + # Clone flash attention + pip install -U packaging ninja --no-cache-dir + git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2 + +build-flash-attention-v2-cuda: flash-attention-v2-cuda + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda) + cd flash-attention-v2 && git submodule update --init --recursive + cd flash-attention-v2 && python setup.py build + +install-flash-attention-v2-cuda: build-flash-attention-v2-cuda + cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install + +flash-attention-v2-rocm: + # Clone flash attention + pip install -U packaging ninja --no-cache-dir + git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 + +build-flash-attention-v2-rocm: flash-attention-v2-rocm + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) + cd flash-attention-v2 && git submodule update --init --recursive + cd flash-attention-v2 && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build + +install-flash-attention-v2-rocm: build-flash-attention-v2-rocm + cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install diff --git a/server/Makefile-selective-scan b/server/Makefile-selective-scan new file mode 100644 index 00000000..b93b517d --- /dev/null +++ 
b/server/Makefile-selective-scan @@ -0,0 +1,28 @@ +selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137 + +causal-conv1d: + rm -rf causal-conv1d + git clone https://github.com/Dao-AILab/causal-conv1d.git + +build-causal-conv1d: causal-conv1d + cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag + cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build + +install-causal-conv1d: build-causal-conv1d + pip uninstall causal-conv1d -y || true + cd causal-conv1d/ && pip install . + +# selective-scan dependends on causal-conv1d +selective-scan: + rm -rf mamba + git clone https://github.com/state-spaces/mamba.git mamba + +build-selective-scan: selective-scan + cd mamba/ && git fetch && git checkout $(selective_scan_commit) + cd mamba && python setup.py build + +install-selective-scan: install-causal-conv1d build-selective-scan + pip uninstall selective-scan-cuda -y || true + cd mamba && pip install . + +build-all: build-causal-conv1d build-selective-scan diff --git a/server/Makefile-vllm b/server/Makefile-vllm new file mode 100644 index 00000000..62fa413f --- /dev/null +++ b/server/Makefile-vllm @@ -0,0 +1,25 @@ +vllm-cuda: + # Clone vllm + pip install -U ninja packaging --no-cache-dir + git clone https://github.com/Narsil/vllm.git vllm + +build-vllm-cuda: vllm-cuda + cd vllm && git fetch && git checkout b5dfc61db88a81069e45b44f7cc99bd9e62a60fa + cd vllm && python setup.py build + +install-vllm-cuda: build-vllm-cuda + pip uninstall vllm -y || true + cd vllm && python setup.py install + +vllm-rocm: + # Clone vllm + pip install -U ninja packaging --no-cache-dir + git clone https://github.com/fxmarty/rocm-vllm.git vllm + +build-vllm-rocm: vllm-rocm + cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479 + cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install + +install-vllm-rocm: build-vllm-rocm + pip uninstall vllm -y || true + cd vllm && python setup.py install diff --git a/server/README.md b/server/README.md new file mode 100644 index 00000000..b8208f9e --- /dev/null +++ b/server/README.md @@ -0,0 +1,15 @@ +# Text Generation Inference Python gRPC Server + +A Python gRPC server for Text Generation Inference + +## Install + +```shell +make install +``` + +## Run + +```shell +make run-dev +``` diff --git a/server/custom_kernels/custom_kernels/fused_attention_cuda.cu b/server/custom_kernels/custom_kernels/fused_attention_cuda.cu new file mode 100644 index 00000000..60f9f028 --- /dev/null +++ b/server/custom_kernels/custom_kernels/fused_attention_cuda.cu @@ -0,0 +1,250 @@ +#include +#include +#include +#include +#include + +#include + +/** +* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda +* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu +**/ + +// Available in pytorch main +//#define DISPATCH_CASE_FLOATING_TYPES(...) 
\ +// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + +/* +* Forward passes +*/ + +/** +* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype +**/ +template +__global__ void forward_masked_softmax_kernel( + const torch::PackedTensorAccessor32 attention_scores, // [B, KV] + const torch::PackedTensorAccessor32 mask, // [B, KV] + torch::PackedTensorAccessor32 result, // [B, KV] + const int64_t effective_kv_length, + const dim3 blockDim, + const int64_t rows_per_block, + const int64_t kv_length, + const int64_t batch_size +) { + const auto row_id = threadIdx.x / effective_kv_length; + const auto effective_kv_length_id = threadIdx.x % effective_kv_length; + const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread; + auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread; + kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_; + const auto kv_length_end = kv_length_end_; + + const auto batch_id = blockIdx.x * rows_per_block + row_id; + + // We need 2 float storage for each row, one for max computation, the other for normalizing exponential + extern __shared__ float temp_storage[]; + const auto row_id_mem_offset = row_id * 2; + if (effective_kv_length_id == 0) { + temp_storage[row_id_mem_offset] = -std::numeric_limits::infinity(); + temp_storage[row_id_mem_offset + 1] = 0; + } + __syncthreads(); + + // Compute mask and max + if (batch_id < batch_size) { + float thread_max = -std::numeric_limits::infinity(); + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + const float candidate = attention_scores[batch_id][kv_length_id]; + thread_max = (thread_max < candidate) ? candidate : thread_max; + } + } + if (thread_max != -std::numeric_limits::infinity()) { + // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max); + } + } + + __syncthreads(); + + // Compute exp(elt - max) masked + float exponential[min_kv_length_shard_size_per_thread]; + if (batch_id < batch_size) { + float thread_add = 0; + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + exponential[kv_length_id - kv_length_start] = std::exp(static_cast(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]); + thread_add = thread_add + exponential[kv_length_id - kv_length_start]; + } else { + exponential[kv_length_id - kv_length_start] = 0.; + } + } + if (thread_add > 0) { + // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add); + } + } + + __syncthreads(); + + // Compute softmax + if (batch_id < batch_size) { + // If sum of all exponential is 0, we set the softmax values to 0 + if (temp_storage[row_id_mem_offset + 1] == 0.) 
{ + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = 0.; + } + } else { + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = static_cast(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]); + } + } + } +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::tuple>, at::Tensor> forward( + const at::Tensor query, + const at::Tensor key, + const at::Tensor value, + const std::optional> layer_past, + const at::Tensor attention_mask, + const std::optional head_mask, + const float inv_norm_factor, + const int num_heads, + const bool use_cache +) { + auto query_layer = query; + auto key_layer = key; + auto value_layer = value; + + if (layer_past) { + const auto past_key = (*layer_past).at(0); + const auto past_value = (*layer_past).at(1); + key_layer = at::cat({past_key, key_layer}, 2); + value_layer = at::cat({past_value, value_layer}, 2); + } + + std::optional> present; + if (use_cache) { + present = {key_layer, value_layer}; + } else { + present = {}; + } + + const auto batch_size = query_layer.size(0); + const auto q_length = query_layer.size(2); + const auto attn_head_size = query_layer.size(3); + const auto batch_size_times_num_heads = batch_size * num_heads; + const auto kv_length = key_layer.size(2); + + const auto query_view = query_layer.reshape({batch_size_times_num_heads, q_length, attn_head_size}); + auto key_view = key_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}).transpose(1, 2); + auto value_view = value_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}); + + auto query_scaled = query_view * inv_norm_factor; + auto attention_scores = at::bmm(query_scaled, key_view); + + // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype` + at::Tensor attention_probs; + if (true) { + // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors + const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length}); + const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length}); + + // Custom kernel + attention_probs = at::empty_like(attention_scores_2d); + + // Check that inputs and contiguous + cuda tensors + CHECK_INPUT(attention_scores_2d); + CHECK_INPUT(attention_mask_2d); + + // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out + // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] { + /* + * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/ + * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf + * - SMs: 108 + * - TPCs: 56 (What's that?) 
+ * - Memory size: 40 GB + * - L2 Cache size: 40960 KB (shared across all SMs) + * - L1/Shared memory size: 192 KB (shared across all threads within a SM) + * - Max Threads / SM: 2048 + * - Max Thread Blocks / SM: 32 + */ + + /* + * We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block + * with multiple threads as we need to `sync_threads` to run exponential sum. + * We maximise the usage of threads within a single block + */ + // TODO @thomasw21 figure out everything warp related: + // - why do they have to be power of 2 + // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048 + const auto MAX_THREADS_PER_SM = 1024; + // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD` + const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4; + // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)` + const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1; + const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length; + const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1; + + const dim3 gridDim(num_blocks); // Number of blocks that run + const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block + const int shared_mem_forward = rows_per_block * 2 * sizeof(float); + + // 192 * 2 ** 10 + // const auto MAX_L1_MEMORY = 196608; + // const auto MAX_SMs = 108; + // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation."); + // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger."); + // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. 
Raising as require requested threads is higher."); + + forward_masked_softmax_kernel<<>>( + attention_scores_2d.packed_accessor32(), + attention_mask_2d.packed_accessor32(), + attention_probs.packed_accessor32(), + effective_kv_length, + blockDim, + rows_per_block, + kv_length, + batch_size_times_num_heads * q_length + ); + }); + attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length}); + } else { + // Pytorch C++ API + auto input_dtype = attention_scores.scalar_type(); + if (input_dtype == at::ScalarType::Float) { + attention_scores = attention_scores.to(at::ScalarType::Float); + }; + // TODO @thomasw21 Figure out how to get minimum value + auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34); + attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype); + } + + auto context_layer = attention_probs.bmm(value_view); + + // `_merge_heads` + context_layer = context_layer.view({batch_size, num_heads, q_length, attn_head_size}); + context_layer = context_layer.permute({0, 2, 1, 3}); + context_layer = context_layer.reshape({batch_size, q_length, attn_head_size * num_heads}); + + return std::make_tuple(context_layer, present, attention_probs); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "forward", + &forward, + "GPT-Neox attention mechanism forward (CUDA)" + ); +} diff --git a/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu b/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu new file mode 100644 index 00000000..8206c3e0 --- /dev/null +++ b/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu @@ -0,0 +1,250 @@ +#include +#include +#include +#include +#include + +#include + +/** +* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda +* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu +**/ + +// Available in pytorch main +//#define DISPATCH_CASE_FLOATING_TYPES(...) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + +/* +* Forward passes +*/ + +/** +* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype +**/ +template +__global__ void forward_masked_softmax_kernel( + const torch::PackedTensorAccessor32 attention_scores, // [B, KV] + const torch::PackedTensorAccessor32 mask, // [B, KV] + torch::PackedTensorAccessor32 result, // [B, KV] + const int64_t effective_kv_length, + const dim3 blockDim, + const int64_t rows_per_block, + const int64_t kv_length, + const int64_t batch_size +) { + const auto row_id = threadIdx.x / effective_kv_length; + const auto effective_kv_length_id = threadIdx.x % effective_kv_length; + const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread; + auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread; + kv_length_end_ = (kv_length_end_ > kv_length) ? 
kv_length : kv_length_end_; + const auto kv_length_end = kv_length_end_; + + const auto batch_id = blockIdx.x * rows_per_block + row_id; + + // We need 2 float storage for each row, one for max computation, the other for normalizing exponential + extern __shared__ float temp_storage[]; + const auto row_id_mem_offset = row_id * 2; + if (effective_kv_length_id == 0) { + temp_storage[row_id_mem_offset] = -std::numeric_limits::infinity(); + temp_storage[row_id_mem_offset + 1] = 0; + } + __syncthreads(); + + // Compute mask and max + if (batch_id < batch_size) { + float thread_max = -std::numeric_limits::infinity(); + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + const float candidate = attention_scores[batch_id][kv_length_id]; + thread_max = (thread_max < candidate) ? candidate : thread_max; + } + } + if (thread_max != -std::numeric_limits::infinity()) { + // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max); + } + } + + __syncthreads(); + + // Compute exp(elt - max) masked + float exponential[min_kv_length_shard_size_per_thread]; + if (batch_id < batch_size) { + float thread_add = 0; + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + exponential[kv_length_id - kv_length_start] = std::exp(static_cast(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]); + thread_add = thread_add + exponential[kv_length_id - kv_length_start]; + } else { + exponential[kv_length_id - kv_length_start] = 0.; + } + } + if (thread_add > 0) { + // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add); + } + } + + __syncthreads(); + + // Compute softmax + if (batch_id < batch_size) { + // If sum of all exponential is 0, we set the softmax values to 0 + if (temp_storage[row_id_mem_offset + 1] == 0.) 
{ + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = 0.; + } + } else { + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = static_cast(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]); + } + } + } +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::tuple>, at::Tensor> forward( + const at::Tensor fused_qkv, + const std::optional> layer_past, + const at::Tensor alibi, + const at::Tensor attention_mask, + const std::optional head_mask, + const float beta, + const float inv_norm_factor, + const int num_heads, + const bool use_cache +) { + const auto batch_size = fused_qkv.size(0); + const auto q_length = fused_qkv.size(1); + const auto three_times_hidden_size = fused_qkv.size(2); + const auto head_dim = three_times_hidden_size / (3 * num_heads); + const auto batch_size_times_num_heads = batch_size * num_heads; + + // `split_heads` + const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim}); + const auto tensor_list = fused_qkv_view.split(head_dim, -1); + const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim}); + auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length}); + auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim}); + + if (layer_past) { + const auto past_key = (*layer_past).at(0); + const auto past_value = (*layer_past).at(1); + key_layer = at::cat({past_key, key_layer}, 2); + value_layer = at::cat({past_value, value_layer}, 1); + } + + std::optional> present; + if (use_cache) { + present = {key_layer, value_layer}; + } else { + present = {}; + } + + auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor); + + // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype` + at::Tensor attention_probs; + if (true) { + const auto kv_length = key_layer.size(2); + + // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors + const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length}); + const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length}); + + // Custom kernel + attention_probs = at::empty_like(attention_scores_2d); + + // Check that inputs and contiguous + cuda tensors + CHECK_INPUT(attention_scores_2d); + CHECK_INPUT(attention_mask_2d); + + // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out + // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] { + /* + * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/ + * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf + * - SMs: 108 + * - TPCs: 56 (What's that?) 
+ * - Memory size: 40 GB + * - L2 Cache size: 40960 KB (shared across all SMs) + * - L1/Shared memory size: 192 KB (shared across all threads within a SM) + * - Max Threads / SM: 2048 + * - Max Thread Blocks / SM: 32 + */ + + /* + * We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block + * with multiple threads as we need to `sync_threads` to run exponential sum. + * We maximise the usage of threads within a single block + */ + // TODO @thomasw21 figure out everything warp related: + // - why do they have to be power of 2 + // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048 + const auto MAX_THREADS_PER_SM = 1024; + // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD` + const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4; + // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)` + const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1; + const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length; + const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1; + + const dim3 gridDim(num_blocks); // Number of blocks that run + const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block + const int shared_mem_forward = rows_per_block * 2 * sizeof(float); + + // 192 * 2 ** 10 + // const auto MAX_L1_MEMORY = 196608; + // const auto MAX_SMs = 108; + // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation."); + // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger."); + // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. 
Raising as require requested threads is higher."); + + forward_masked_softmax_kernel<<>>( + attention_scores_2d.packed_accessor32(), + attention_mask_2d.packed_accessor32(), + attention_probs.packed_accessor32(), + effective_kv_length, + blockDim, + rows_per_block, + kv_length, + batch_size_times_num_heads * q_length + ); + }); + attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length}); + } else { + // Pytorch C++ API + auto input_dtype = attention_scores.scalar_type(); + if (input_dtype == at::ScalarType::Float) { + attention_scores = attention_scores.to(at::ScalarType::Float); + }; + // TODO @thomasw21 Figure out how to get minimum value + auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34); + attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype); + } + + auto context_layer = attention_probs.bmm(value_layer); + + // `_merge_heads` + context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim}); + context_layer = context_layer.permute({0, 2, 1, 3}); + context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3}); + + return std::make_tuple(context_layer, present, attention_probs); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "forward", + &forward, + "Bloom attention mechanism forward (CUDA)" + ); +} diff --git a/server/custom_kernels/setup.py b/server/custom_kernels/setup.py new file mode 100644 index 00000000..69f6b72a --- /dev/null +++ b/server/custom_kernels/setup.py @@ -0,0 +1,24 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +extra_compile_args = ["-std=c++17"] +if not torch.version.hip: + extra_compile_args.append("-arch=compute_80") + +setup( + name="custom_kernels", + ext_modules=[ + CUDAExtension( + name="custom_kernels.fused_bloom_attention_cuda", + sources=["custom_kernels/fused_bloom_attention_cuda.cu"], + extra_compile_args=extra_compile_args, + ), + CUDAExtension( + name="custom_kernels.fused_attention_cuda", + sources=["custom_kernels/fused_attention_cuda.cu"], + extra_compile_args=extra_compile_args, + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/server/examples/test_cases.py b/server/examples/test_cases.py index 49c39db7..c7b8ec6b 100644 --- a/server/examples/test_cases.py +++ b/server/examples/test_cases.py @@ -7,12 +7,14 @@ import torch from typing import Optional + @dataclasses.dataclass class LoraSpec: lora_prompts: list[str] base_prompts: list[str] weight_path: Optional[str] = None + @dataclasses.dataclass class DemoSpec: weight_url: str @@ -142,6 +144,7 @@ def generate_prompts(self): """, ) +DEMO["Dogge/llama-3-70B-instruct-uncensored-lora"] = DEMO["tjluyao/llama-3-8b-oaast"] DEMO["abcdabcd987/gsm8k-llama2-7b-lora-16"] = DemoSpec( weight_url="https://huggingface.co/abcdabcd987/gsm8k-llama2-7b-lora-16/resolve/main/gsm8k-r16.punica.pt", @@ -277,3 +280,33 @@ def generate_prompts(self): {"question": "There are two warehouses. The first warehouse has twice as many boxes as the second warehouse. If the first warehouse has 400 boxes, how many boxes are there in both warehouses combined? 
Let's program in Python in the response.","answer": "# define the number of boxes in the first warehouse\nfirst_warehouse_boxes = 400\n# the first warehouse has twice as many boxes as the second warehouse\nsecond_warehouse_boxes = first_warehouse_boxes / 2\n# the total number of boxes is the sum of the boxes in both warehouses\ntotal_boxes = first_warehouse_boxes + second_warehouse_boxes\n# print the result\nprint(int(total_boxes))"} """, ) + +DEMO["REILX/Qwen1.5-7B-Chat-750Mb-lora"] = DemoSpec( + weight_url="https://huggingface.co/REILX/Qwen1.5-7B-Chat-750Mb-lora/blob/main/adapter_model.safetensors", + system="You are a helpful assistant.", + lora_template="user\n{question}\nmodel", + base_template="user\n{question}\nmodel", + example_template="user\n{question}\nmodel\n{answer}\n", + jsonl=r""" +{"question": "If Juan takes 14 seconds to run y yards, how many seconds will it take him to run x yards at the same rate? Answer Choices: (A) 14x/y (B) 14y/x (C) x/ (14y) (D) 14/ (xy) (E) xy/14","answer": "Let's solve the multi-choice question step by step.\nThis problem is testing us on the Rate x Time = Distance relationship. This relationship also tells us that Rate = Distance/Time and Time = Distance/Rate.\nUltimately, we are looking for how many seconds it will take Juan to run x yards. Thus, the equation we’ll use is: Time = Distance/Rate. We know the distance is x yards, and we need to find Juan’s rate.\nWe can find Juan’s rate as follows: Rate = Distance/Time = y yards/14 seconds\nUsing that rate, we need to determine how long it takes him to run x yards at the same rate. So we have:\nTime = Distance/Rate\nTime = x yards/(y yards/14 seconds)\nTime = (x yards) x (14 seconds/y yards)\nTime = 14x/y seconds\nThe answer is A."} +{"question": "Find out the number of ways in which 6 rings of different types can be worn in 3 fingers? Answer Choices: (A) 120 (B) 720 (C) 125 (D) 729 (E) None of these Please respond by writing a program in Python.","answer": "# The number of ways to arrange 6 rings on 3 fingers can be calculated using the formula for permutations with repetition: n^r, where n is the number of objects (rings) and r is the number of positions (fingers).\nnum_ways = 3**6\nprint(num_ways)"} +{"question": "0.01 x 0.02=? Answer Choices: (A) 3.15e-05 (B) 0.000315 (C) 0.2 (D) 0.02 (E) 0.0002","answer": "Let's think about the multi-choice question.\n1 x 2 = 2\nSum of decimal places =4\nTherefore, 0.01 x 0.02 = 0.0002\nThe answer is E."} +{"question": "Jordan read 120 French novels last holiday. His brother Alexandre read 1/10 of what Jordan read. How many more novels did Jordan read than Alexandre?","answer": "Alexandre read 120 * 1/10 = 12 novels.\nSo, Jordan read 120 - 12 = 108 novels more than Alexandre.\nThe answer is 108"} +{"question": "Every bedtime, Juwella reads a book. Three nights ago, she read 15 pages. Two nights ago she read twice that many pages, while last night she read 5 pages more than the previous night. She promised to read the remaining pages of the book tonight. If the book has 100 pages, how many pages will she read tonight?","answer": "Juwella read 15 x 2 = 30 pages two nights ago.\nShe read 30 + 5 = 35 pages last night.\nShe read a total of 15 + 30 + 35 = 80 pages for three nights.\nThus, she needs to read 100 - 80 = 20 pages tonight to finish the book.\nThe answer is 20"} +{"question": "A flagpole casts a shadow of 10 meters long. At the same time, the angle of elevation from the tip of the shadow to the top of the flagpole is 60 degrees. 
Calculate the height of the flagpole.","answer": "To calculate the height of the flagpole, we can use the tangent function in trigonometry. The tangent of an angle in a right triangle is the ratio of the length of the side opposite the angle to the length of the side adjacent to the angle.\n\nIn this case, the angle of elevation is 60 degrees, the side adjacent to the angle is the length of the shadow (10 meters), and the side opposite the angle is the height of the flagpole (which we want to find).\n\nWe can set up the equation using the tangent function:\n\ntan(60°) = height / 10\n\nThe tangent of 60 degrees is equal to the square root of 3 (approximately 1.732).\n\n1.732 = height / 10\n\nNow, we can solve for the height of the flagpole:\n\nheight = 1.732 * 10\nheight ≈ 17.32 meters\n\nThe height of the flagpole is approximately 17.32 meters."} +{"question": "Let $f(x) = 2x^4-17x^3+26x^2-24x-60$. Find $f(7)$. Let's write a program.","answer": "import sympy as sp\n\n# define the variable\nx = sp.symbols('x')\n\n# define the function\nf = 2*x**4 - 17*x**3 + 26*x**2 - 24*x - 60\n\n# calculate f(7)\nf_7 = f.subs(x, 7)\n\nprint(f_7)"} +{"question": "If the remainder is 12 when the integer n is divided by 22, what is the remainder when 2n is divided by 11? Answer Choices: (A) 0 (B) 2 (C) 3 (D) 6 (E) 7","answer": "Let's think about the multi-choice question.\nn = 22k+12\n2n = 2(22k+12) = 4*11k + 24 = 4*11k + 2*11 + 2 = 11j+2.\nThe answer is B."} +{"question": "A father said to his son, I was as old as you are at present at the time of your birth. If the father's age is 46 years now, the son's age five years back was?\nAnswer Choices: (A) 16 years (B) 14 years (C) 18 years (D) 19 years (E) 15 years","answer": "Let's solve the multi-choice question step by step.\nLet the son's present age be x years.\nThen, (46 - x) = x\n2x = 46 => x = 23\nSon's age 5 years back = (23 - 5)\n= 18 years.\nThe answer is C"} +{"question": "In a two-digit number, if it is known that its unit's digit exceeds its ten's digit by 3 and that the product of the given number and the sum of its digits is equal to 517, then the number is:\nAnswer Choices: (A) 14 (B) 36 (C) 47 (D) 58 (E) 63","answer": "Let's reason about the multi-choice question.\nLet the ten's digit be x. Then, unit's digit = x + 3. Number = 10x + (x + 3) = 11x + 3\n(11x + 3)(2x + 3) = 517\n22(x)(x)+ 39x- 508 = 0\n(x - 4)(22x + 127) = 0\nx = 4\nHence, required number = 11x + 3 = 47.\nThe answer is C"} +{"question": "In a batch of 30 apples, 1/6 are too small and 1/3 are not ripe. The others are perfect. Calculate the number of perfect apples.","answer": "Too small:30*1/6=5\nNot ripe:30*1/3=10\nPerfect:30-10-5=15\nThe answer is 15"} +{"question": "If a population of bacteria starts with 50 organisms and their logistic growth is defined by the differential equation dN/dt = 0.4N(1- N/100), where N is the population size and t is time in days, how long will it take for the population to reach a carrying capacity of 100 organisms?","answer": "To find the time it takes for the population to reach a carrying capacity of 100 organisms, we need to solve the logistic growth differential equation for N(t) and then find the value of t when N(t) = 100.\n\nThe logistic growth differential equation is given by:\n\ndN/dt = 0.4N(1 - N/100)\n\nFirst, we need to separate the variables:\n\ndN/N(1 - N/100) = 0.4 dt\n\nNow, we integrate both sides:\n\n∫(1/N(1 - N/100)) dN = ∫(0.4) dt\n\nTo solve the integral on the left side, we can use partial fraction decomposition. 
Let:\n\n1/N(1 - N/100) = A/N + B/(1 - N/100)\n\nMultiplying both sides by N(1 - N/100) to clear the denominators, we get:\n\n1 = A(1 - N/100) + B(N)\n\nNow, we can solve for A and B by choosing values for N that will eliminate one of the variables:\n\nIf N = 0, then:\n\n1 = A(1) + B(0) => A = 1\n\nIf N = 100, then:\n\n1 = A(0) + B(100) => B = 1/100\n\nNow we can rewrite the integral as:\n\n∫(1/N + 1/100(1 - N/100)) dN = ∫(0.4) dt\n\nIntegrating both sides, we get:\n\nln|N| - ln|1 - N/100| = 0.4t + C\n\nNow, we can use the initial condition N(0) = 50 to solve for the constant C:\n\nln|50| - ln|1 - 50/100| = 0.4(0) + C\n\nln|50| - ln|1/2| = C\n\nC = ln(100)\n\nNow, we can rewrite the equation as:\n\nln|N| - ln|1 - N/100| = 0.4t + ln(100)\n\nTo find the time it takes for the population to reach a carrying capacity of 100 organisms, we need to solve for t when N(t) = 100:\n\nln|100| - ln|1 - 100/100| = 0.4t + ln(100)\n\nln(100) - ln(0) = 0.4t + ln(100)\n\nSince ln(0) is undefined, we can conclude that the population will never exactly reach a carrying capacity of 100 organisms. However, it will approach 100 as time goes to infinity."} +{"question": "Robin is traveling from one end of a forest to the other. In order to find her way back, she is leaving morsels of bread in the following pattern: 2 morsels of RICE , 3 morsels of white, and 1 morsel of rye. The pattern repeats after she leaves the morsel of rye. If Robin drops 2,000 morsels of bread, what are the last 3 morsels of bread that she drops?\nAnswer Choices: (A) rye − wheat − wheat (B) RY-R-R (C) white − rye − wheat (D) white − white − white (E) white − white − rye","answer": "Let's think about the multi-choice question.\nQ is nothing BUT asking about remainder..\ntotal morsels thrown before the pattern is repeated = 2+3+1 = 6..\nso REMAINDER when 2000 is div by 6..\n1998 is EVEN and also div by 3, so remainder is 2..\nLast 3 of morsels are 1998-1999-2000\nso he uses the last of morsel of pattern in 1998 two of the first morsels of the pattern in 1999 and 2000, and the pattern is R, R, W, W, W, RY..\nans RY-R-R\nThe answe is B"} +{"question": "Great! Here's a numerical problem for the student to solve: At what point does the tangent line to the curve y = x^2 - 2x + 1 have a slope of 4?","answer": "To find the point where the tangent line to the curve y = x^2 - 2x + 1 has a slope of 4, we first need to find the derivative of the function with respect to x. The derivative represents the slope of the tangent line at any point on the curve.\n\nThe function is y = x^2 - 2x + 1.\n\nUsing the power rule for differentiation, we get:\n\ndy/dx = 2x - 2\n\nNow, we want to find the point where the slope (dy/dx) is equal to 4:\n\n4 = 2x - 2\n\nNow, we'll solve for x:\n\n4 + 2 = 2x\n6 = 2x\nx = 3\n\nNow that we have the x-coordinate, we can find the corresponding y-coordinate by plugging x back into the original function:\n\ny = (3)^2 - 2(3) + 1\ny = 9 - 6 + 1\ny = 4\n\nSo, the point where the tangent line to the curve y = x^2 - 2x + 1 has a slope of 4 is (3, 4)."} +{"question": "the average ( arithmetic mean ) of 24 , 35 , and 58 is 6 more than the average of 19 , 51 , and x . what is x ? Let's write a program.","answer": "n0 = 24.0\nn1 = 35.0\nn2 = 58.0\nn3 = 6.0\nn4 = 19.0\nn5 = 51.0\nt0 = n0 + n1\nt1 = n3 * 3.0\nt2 = n2 + t0\nt3 = n4 + t1\nt4 = n5 + t3\nanswer = t2 - t4\nprint(answer)"} +{"question": "the radius of a semi circle is 2.1 cm then its perimeter is ? 
Please write a program to solve it","answer": "import math\nn0 = 2.1\nt0 = 2 * math.pi * n0\nt1 = n0 * 2.0\nt2 = t0 / 2.0\nanswer = t2 + t1\nprint(answer)"} +{"question": "Zoe made a total of $8,000 cleaning pools and babysitting. She babysat Julie three times as often as she babysat Zachary. The number of times she babysat Zachary was 1/5 the number of times she babysat Chloe. If Zoe made $600 babysitting Zachary, how much did she earn from pool cleaning? Let's write a program.","answer": "# define the variables\ntotal_earning = 8000 # total earning\nearning_zachary = 600 # earning from Zachary\n\n# Since babysitting Zachary was 1/5 the number of times she babysat Chloe,\n# and she babysat Julie three times as often as she babysat Zachary,\n# so earning from Julie and Chloe can be expressed in terms of earning from Zachary\nearning_julie = 3 * earning_zachary\nearning_chloe = 5 * earning_zachary\n\n# calculate total earning from babysitting\ntotal_babysitting = earning_julie + earning_chloe + earning_zachary\n\n# calculate earning from pool cleaning\nearning_pool_cleaning = total_earning - total_babysitting\n\nprint(earning_pool_cleaning)"} +{"question": "Jill bought 5 packs of red bouncy balls and 4 packs of yellow bouncy balls. Each package contained 18 bouncy balls. How many more red bouncy balls than yellow bouncy balls did Jill buy?","answer": "Jill bought 5 packs * 18 red balls = 90 red balls.\nJill bought 4 packs * 18 yellow balls = 72 red balls.\nJill had 90 red – 72 yellow = 18 more red balls.\nThe answer is 18"} +{"question": "A box contains nine bulbs out of which 4 are defective. If four bulbs are chosen at random, find the probability that exactly three bulbs are good? Answer Choices: (A) 20/69 (B) 20/63 (C) 20/62 (D) 20/29 (E) 20/61","answer": "Let's think about the multi-choice question.\nRequired probability\n= (⁵C₃ . ⁴C₁)/⁹C₄\n= (10 * 4)/126\n= 20/63\nThe answer is B"} +{"question": "There are two warehouses. The first warehouse has twice as many boxes as the second warehouse. If the first warehouse has 400 boxes, how many boxes are there in both warehouses combined? 
Let's program in Python in the response.","answer": "# define the number of boxes in the first warehouse\nfirst_warehouse_boxes = 400\n# the first warehouse has twice as many boxes as the second warehouse\nsecond_warehouse_boxes = first_warehouse_boxes / 2\n# the total number of boxes is the sum of the boxes in both warehouses\ntotal_boxes = first_warehouse_boxes + second_warehouse_boxes\n# print the result\nprint(int(total_boxes))"} +""", +) diff --git a/server/examples/test_causal.py b/server/examples/test_causal.py new file mode 100644 index 00000000..9d5d7979 --- /dev/null +++ b/server/examples/test_causal.py @@ -0,0 +1,97 @@ +from text_generation_server.pb import generate_pb2 +import torch +from text_generation_server.models.flashinfer_causal_lm import ( + FlashinferLM, + FlashinferBatch, +) +from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch +import random, json +from test_cases import DEMO, LoraSpec + +# Load demo inputs +lora_specs = {} +for name, spec in DEMO.items(): + lora_prompts, base_prompts = spec.generate_prompts() + lora_specs[name] = LoraSpec(lora_prompts, base_prompts) + + +# Create input requests +def make_input(lora_id, lora_or_base, id=0, promptOverride=None): + if lora_or_base == "lora": + prompts = lora_specs[lora_id].lora_prompts + elif lora_or_base == "base" or lora_or_base == "empty": + prompts = lora_specs[lora_id].base_prompts + lora_id = "empty" + else: + raise ValueError(f"Unknown lora_or_base={lora_or_base}") + prompt = random.choice(prompts) if not promptOverride else promptOverride + inputs = prompt + + request = generate_pb2.Request( + id=id, + inputs=inputs, + truncate=256, + prefill_logprobs=True, + top_n_tokens=20, + parameters=generate_pb2.NextTokenChooserParameters( + temperature=0.9, + top_k=10, + top_p=0.9, + typical_p=0.9, + repetition_penalty=1.1, + ), + stopping_parameters=generate_pb2.StoppingCriteriaParameters( + max_new_tokens=256, stop_sequences=[], ignore_eos_token=True + ), + lora_id=lora_id, + ) + return request + + +flash = False + +if flash: + service = FlashinferLM( + model_type="llama", model_id="meta-llama/Llama-2-7b-hf", lora_ids=["empty"] + ) +else: + service = CausalLM(model_id="meta-llama/Llama-2-7b-hf") +requests = [ + make_input("abcdabcd987/gsm8k-llama2-7b-lora-16", "base", id=0) +] # , promptOverride= "test")] + +tokenizer = service.tokenizer +batch = generate_pb2.Batch(id=0, requests=requests, size=len(requests)) +if flash: + pb_batch = FlashinferBatch.from_pb( + batch, tokenizer, torch.float16, torch.device("cuda") + ) + ids = service.add_request(pb_batch) +else: + pb_batch = CausalLMBatch.from_pb( + batch, tokenizer, torch.float16, torch.device("cuda") + ) + +display_results = {} + +# service.warmup(pb_batch) + +while True: + if flash: + generations, _, _ = service.generate_token(FlashinferBatch.Empty(batch.id)) + else: + generations, _, _ = service.generate_token(pb_batch) + for gen in generations: + if gen.generated_text: + display_results[gen.request_id] = [ + "Prompt: " + + tokenizer.decode(gen.prefill_tokens.token_ids) + + "\nAnswer: " + + gen.generated_text.text + ] + if all([g.generated_text for g in generations]): + break + +for id in display_results: + print(str(id) + "=" * 30) + print("".join(display_results[id])) diff --git a/server/examples/test_llava.py b/server/examples/test_llava.py index 45b3e649..9df401d9 100644 --- a/server/examples/test_llava.py +++ b/server/examples/test_llava.py @@ -9,27 +9,30 @@ tokenizer = model.tokenizer prompts = [ - 'How many people are in the image?', - 
'What is the main object in the image?', - 'What is the mood of the image?', - 'What is the setting of the image?', - 'What is the image about?', + "How many people are in the image?", + "What is the main object in the image?", + "What is the mood of the image?", + "What is the setting of the image?", + "What is the image about?", ] + def load_img_base64s(img_path): with open(img_path, "rb") as image_file: img_encoded = base64.b64encode(image_file.read()) return img_encoded - + + def get_input(prompt): - input = 'USER: '+ prompt + ' ASSISTANT: ' + input = "USER: " + prompt + " ASSISTANT: " return input -def make_input(jpg_path, id = 0): + +def make_input(jpg_path, id=0): prompt = random.choice(prompts) request = generate_pb2.Request( inputs=get_input(prompt), - inputb = load_img_base64s(jpg_path), + inputb=load_img_base64s(jpg_path), lora_id=None, id=id, truncate=1024, @@ -45,16 +48,18 @@ def make_input(jpg_path, id = 0): repetition_penalty=1.0, frequency_penalty=0.1, watermark=True, - grammar='', - grammar_type=0), + grammar="", + grammar_type=0, + ), stopping_parameters=generate_pb2.StoppingCriteriaParameters( - max_new_tokens=1024, - stop_sequences=[], - ignore_eos_token=True)) + max_new_tokens=1024, stop_sequences=[], ignore_eos_token=True + ), + ) return request -requests = [make_input('test.jpg') for _ in range(5)] -batch = generate_pb2.Batch(id = 0, requests = requests, size = len(requests)) + +requests = [make_input("test.jpg") for _ in range(5)] +batch = generate_pb2.Batch(id=0, requests=requests, size=len(requests)) pb_batch = LlavaBatch.from_pb(batch, tokenizer, torch.float16, torch.device("cuda")) results = [] @@ -64,4 +69,4 @@ def make_input(jpg_path, id = 0): if gen.generated_text is not None: results.append(gen.generated_text.text) -print(results) \ No newline at end of file +print(results) diff --git a/server/examples/test_local_api.py b/server/examples/test_local_api.py index e63e2ad2..43756946 100644 --- a/server/examples/test_local_api.py +++ b/server/examples/test_local_api.py @@ -1,15 +1,52 @@ from text_generation_server.pb import generate_pb2 import torch -from text_generation_server.models.flashinfer_causal_lm import FlashinferLM, FlashinferBatch +from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama +from text_generation_server.models_flashinfer.flashinfer_gemma import FlashinferGemma +from text_generation_server.models_flashinfer.flashinfer_qwen2 import FlashinferQwen2 +from text_generation_server.models_flashinfer.flashinfer_chatglm import ( + FlashinferChatGLM, +) +import sys + +try: + from text_generation_server.models_flashinfer.flashinfer_mistral import ( + FlashinferMistral, + ) + from text_generation_server.models_flashinfer.flashinfer_phi import FlashinferPhi + from text_generation_server.models_flashinfer.flashinfer_qwen2 import ( + FlashinferQwen2, + ) +except: + print("can't load flashinfer mistral and phi and qwen2 without flash attn") + +from text_generation_server.models_flashinfer.flashinfer_causal_lm import ( + FlashinferBatch, +) import random, json from test_cases import DEMO, LoraSpec +if len(sys.argv) == 2: + test = sys.argv[1] +else: + # test = "gemma" + # test = "llama-3" + # test = 'llama-3-70' + test = "gemma" + # test = 'mistral' + # test = 'qwen1.5-7' + # test = 'qwen1.5-1.8' + # test = 'qwen1.5-70' + # test = 'qwen2-7' + # test = "chatglm4" +print("Testing " + test) + # Load demo inputs lora_specs = {} for name, spec in DEMO.items(): lora_prompts, base_prompts = spec.generate_prompts() lora_specs[name] = 
LoraSpec(lora_prompts, base_prompts) + # Create input requests def make_input(lora_id, lora_or_base, id=0, promptOverride=None): if lora_or_base == "lora": @@ -36,84 +73,251 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None): repetition_penalty=1.1, ), stopping_parameters=generate_pb2.StoppingCriteriaParameters( - max_new_tokens=2048, - stop_sequences=[], - ignore_eos_token=True), - lora_id=lora_id + max_new_tokens=2048, stop_sequences=[], ignore_eos_token=True + ), + lora_id=lora_id, ) return request -test = 'gemma' -# test = 'llama-3' -# test = 'llama-2' -# test = 'mistral' -if test == 'llama-2': +if test == "llama-2": # Load model - service = FlashinferLM(model_type="llama", model_id="meta-llama/Llama-2-7b-hf", - lora_ids=['abcdabcd987/gsm8k-llama2-7b-lora-16']) + service = FlashinferLlama( + model_id="meta-llama/Llama-2-7b-hf", + lora_ids=["abcdabcd987/gsm8k-llama2-7b-lora-16"], + ) # Create an input batch of two queries - requests = [make_input('abcdabcd987/gsm8k-llama2-7b-lora-16', 'base', id=0, promptOverride= "Give me a breif introduction to Byznatine Fault Tolerance and why it is important?"), - make_input('abcdabcd987/gsm8k-llama2-7b-lora-16', 'lora', id=1, promptOverride="Which network interface card is more suitable for distributed systems, Meallanox or Broadcom?")] -elif test == 'llama-3': + requests = [ + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "base", + id=0, + promptOverride="Give me a breif introduction to Byznatine Fault Tolerance and why it is important?", + ), + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "lora", + id=1, + promptOverride="Which network interface card is more suitable for distributed systems, Meallanox or Broadcom?", + ), + ] +elif test == "llama-3": # Load model - service = FlashinferLM(model_type="llama", model_id="tjluyao/llama-3-8b", - lora_ids=['tjluyao/llama-3-8b-math', - 'tjluyao/llama-3-8b-zh']) + service = FlashinferLlama( + model_id="tjluyao/llama-3-8b", + lora_ids=["tjluyao/llama-3-8b-math", "tjluyao/llama-3-8b-zh"], + ) # Test load lora adapters print(service.get_lora_adapters()) # Test remove lora adapters - service.remove_lora_adapters(['llama3-zh']) + service.remove_lora_adapters(["llama3-zh"]) print(service.get_lora_adapters()) service.remove_lora_adapters() print(service.get_lora_adapters()) - service.load_lora_adapters(['tjluyao/llama-3-8b-math', - 'tjluyao/llama-3-8b-oaast', - 'tjluyao/llama-3-8b-zh']) + service.load_lora_adapters( + ["tjluyao/llama-3-8b-math", "tjluyao/llama-3-8b-oaast", "tjluyao/llama-3-8b-zh"] + ) # Create an input batch of two queries - requests = [make_input('tjluyao/llama-3-8b-zh', 'lora', id=0), - make_input('tjluyao/llama-3-8b-oaast', 'lora', id=1), - make_input('tjluyao/llama-3-8b-zh', 'empty', id=2)] -elif test == "gemma": - requests = [make_input("tjluyao/gemma-2b-it-math", "base", id=0), - make_input("tjluyao/gemma-2b-it-math", "lora", id=1), - make_input("monsterapi/gemma-2b-lora-maths-orca-200k", "lora", id=2)] - service = FlashinferLM(model_type="gemma", model_id="google/gemma-2b-it", - lora_ids=['tjluyao/gemma-2b-it-math', - 'monsterapi/gemma-2b-lora-maths-orca-200k']) - # service = FlashinferLM(model_type="gemma", model_id="google/gemma-2b", lora_ids=['tjluyao/gemma-2b-math']) - # service = FlashinferLM(model_type="gemma", model_id="google/gemma-2b", lora_ids=[]) + requests = [ + make_input("tjluyao/llama-3-8b-zh", "lora", id=0), + make_input("tjluyao/llama-3-8b-oaast", "lora", id=1), + make_input("tjluyao/llama-3-8b-zh", "empty", id=2), + ] +elif test == 
"llama-3-70": + # Load model + service = FlashinferLlama( + model_id="TechxGenus/Meta-Llama-3-70B-Instruct-AWQ", + lora_ids=["Dogge/llama-3-70B-instruct-uncensored-lora"], + quantize="AWQ", + ) + # service = FlashinferLlama(model_id="TechxGenus/Meta-Llama-3-70B-Instruct-GPTQ", + # lora_ids=['Dogge/llama-3-70B-instruct-uncensored-lora'], quantize='GPTQ') + # Create an input batch of two queries + requests = [make_input("Dogge/llama-3-70B-instruct-uncensored-lora", "lora", id=0)] +elif test == "gemma": + requests = [ + make_input("tjluyao/gemma-2b-it-math", "lora", id=0), + make_input("tjluyao/gemma-2b-it-math", "lora", id=1), + make_input("tjluyao/gemma-2b-it-math", "lora", id=2), + ] + service = FlashinferGemma( + model_id="google/gemma-2b-it", + lora_ids=[ + "tjluyao/gemma-2b-it-math", + "monsterapi/gemma-2b-lora-maths-orca-200k", + ], + ) + # service = FlashinferGemma(model_id="google/gemma-2b", lora_ids=['tjluyao/gemma-2b-math']) + # service = FlashinferGemma(model_id="google/gemma-2b", lora_ids=[]) # Quantized version - # service = FlashinferLM(model_type="gemma", model_id="TechxGenus/gemma-2b-GPTQ", quantize='gptq') + # service = FlashinferGemma(model_id="TechxGenus/gemma-2b-GPTQ", quantize='gptq') elif test == "mistral": - requests = [make_input("abcdabcd987/gsm8k-llama2-7b-lora-16", "base", id=0, promptOverride="why is deep learning so popular these days?"), - make_input("abcdabcd987/gsm8k-llama2-7b-lora-16", "base", id=1, promptOverride="What are the differences between Manhattan and Brooklyn")] - service = FlashinferLM(model_type="mistral", model_id="mistralai/Mistral-7B-v0.3") + requests = [ + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "base", + id=0, + promptOverride="why is deep learning so popular these days?", + ), + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "base", + id=1, + promptOverride="What are the differences between Manhattan and Brooklyn", + ), + ] + service = FlashinferMistral(model_id="mistralai/Mistral-7B-v0.3") +elif test == "qwen1.5-7": + requests = [ + make_input( + "REILX/Qwen1.5-7B-Chat-750Mb-lora", + "base", + id=0, + promptOverride="给我讲个故事", + ), + make_input( + "REILX/Qwen1.5-7B-Chat-750Mb-lora", + "lora", + id=1, + promptOverride="什么是深度学习?", + ), + ] + + service = FlashinferQwen2( + model_id="Qwen/Qwen1.5-7B-Chat", lora_ids=["REILX/Qwen1.5-7B-Chat-750Mb-lora"] + ) +elif test == "qwen1.5-1.8": + # Todo: Add qwen1.5 1.8b chat lora adapter / Output Repetition Problem + requests = [ + make_input( + "REILX/Qwen1.5-7B-Chat-750Mb-lora", + "base", + id=0, + promptOverride="给我讲个故事", + ) + ] + + service = FlashinferQwen2( + model_id="Qwen/Qwen1.5-1.8B-Chat", lora_ids=["REILX/Qwen1.5-7B-Chat-750Mb-lora"] + ) +elif test == "qwen1.5-70": + # Todo: Add qwen1.5 72b chat lora adapter + requests = [ + make_input( + "REILX/Qwen1.5-7B-Chat-750Mb-lora", + "base", + id=0, + promptOverride="给我讲个故事", + ) + ] + + service = FlashinferQwen2( + model_id="Qwen/Qwen1.5-72B-Chat-GPTQ-Int4", + lora_ids=["REILX/Qwen1.5-7B-Chat-750Mb-lora"], + quantize="gptq", + ) +elif test == "phi": + requests = [ + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "base", + id=0, + promptOverride="why is deep learning so popular these days?", + ), + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "base", + id=1, + promptOverride="What are the differences between Manhattan and Brooklyn", + ), + ] + service = FlashinferPhi(model_id="microsoft/phi-2") +elif test == "phi3": + requests = [ + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "base", + id=0, + 
promptOverride="why is deep learning so popular these days?", + ), + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "base", + id=1, + promptOverride="What are the differences between Manhattan and Brooklyn", + ), + ] + service = FlashinferLlama(model_id="microsoft/Phi-3-mini-4k-instruct") +elif test == "baichuan": + requests = [ + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "base", + id=0, + promptOverride="why is deep learning so popular these days?", + ), + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "base", + id=1, + promptOverride="What are the differences between Manhattan and Brooklyn", + ), + ] + service = FlashinferLlama( + model_id="baichuan-inc/Baichuan2-7B-Chat", trust_remote_code=True + ) +elif test == "qwen2-7": + # Todo: qwen2-7b instruct lora adapter + requests = [ + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "base", + id=0, + promptOverride="给我讲个故事", + ), + ] + service = FlashinferQwen2(model_id="Qwen/Qwen2-7B-Instruct", trust_remote_code=True) + +elif test == "chatglm4": + # Todo: chatglm4-9b lora adapter + requests = [ + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "base", + id=0, + promptOverride="给我讲个故事", + ), + ] + service = FlashinferChatGLM(model_id="THUDM/glm-4-9b-chat", trust_remote_code=True) print(service.get_lora_adapters()) tokenizer = service.tokenizer -batch = generate_pb2.Batch(id = 0, requests = requests, size = len(requests)) -pb_batch = FlashinferBatch.from_pb(batch, tokenizer, torch.float16, torch.device("cuda")) - -# Add input batch to model service -ids = service.add_request(pb_batch) +batch = generate_pb2.Batch(id=0, requests=requests, size=len(requests)) display_results = {} # Iterative generation: each step generates a token for each input in the batch +isPrefill = True while True: - # When calling iterative text generation, we may add new inputs (use pb_batch like above) - # or use an empty batch (use EmptyFlashinferBatch) - generations, _, _ = service.generate_token(FlashinferBatch.Empty(batch.id)) + if isPrefill: + generations, next_batch, _ = service.prefill_batch(batch) + isPrefill = False + else: + generations, next_batch, _, _ = service.decode_batch([next_batch.to_pb()]) + + for gen in generations: + if gen.prefill_tokens: + display_results[gen.request_id] = [ + "Prompt:\n" + + tokenizer.decode(gen.prefill_tokens.token_ids) + + "\nAnswer:\n" + ] + if gen.generated_text: + display_results[gen.request_id] += [gen.generated_text.text] # Stop if all input generations are done - if not generations: + if all([g.generated_text for g in generations]): break - for gen in generations: - if gen.request_id in display_results: - display_results[gen.request_id].append(gen.tokens.texts[0]) - else: - display_results[gen.request_id] = [gen.tokens.texts[0]] for id in display_results: - print(str(id) + '='*30) - print(''.join(display_results[id])) \ No newline at end of file + print(str(id) + "=" * 30) + print("".join(display_results[id])) diff --git a/server/examples/test_local_grpc.py b/server/examples/test_local_grpc.py index b46f20e5..8b92865c 100644 --- a/server/examples/test_local_grpc.py +++ b/server/examples/test_local_grpc.py @@ -12,7 +12,8 @@ lora_prompts, base_prompts = spec.generate_prompts() lora_specs[name] = LoraSpec(lora_prompts, base_prompts) -def make_input(lora_id, lora_or_base): + +def make_input(lora_id, lora_or_base, id=0, promptOverride=None): if lora_or_base == "lora": prompts = lora_specs[lora_id].lora_prompts elif lora_or_base == "base" or lora_or_base == "empty": @@ -20,10 +21,11 @@ 
def make_input(lora_id, lora_or_base): lora_id = "empty" else: raise ValueError(f"Unknown lora_or_base={lora_or_base}") - prompt = random.choice(prompts) + prompt = random.choice(prompts) if not promptOverride else promptOverride inputs = prompt request = generate_pb2.Request( + id=id, inputs=inputs, truncate=256, prefill_logprobs=True, @@ -36,66 +38,56 @@ def make_input(lora_id, lora_or_base): repetition_penalty=1.1, ), stopping_parameters=generate_pb2.StoppingCriteriaParameters( - max_new_tokens=2048, - stop_sequences=[], - ignore_eos_token=True), - lora_id=lora_id + max_new_tokens=2048, stop_sequences=[], ignore_eos_token=True + ), + lora_id=lora_id, ) return request -req1 = make_input('gsm8k', 'base') -req2 = make_input('gsm8k', 'lora') -requests = [req1, req2] + +requests = [ + make_input("tjluyao/gemma-2b-it-math", "base", id=0), + make_input("tjluyao/gemma-2b-it-math", "base", id=1), +] # Assemble input batch -pb_batch_with_inputs = generate_pb2.Batch(id = 0, requests = requests, size = len(requests)) +pb_batch_with_inputs = generate_pb2.Batch(id=0, requests=requests, size=len(requests)) pb_batch_empty = generate_pb2.Batch() with grpc.insecure_channel("unix:///tmp/text-generation-server-0") as channel: stub = generate_pb2_grpc.TextGenerationServiceStub(channel) - # Test adapter loading and offloading - stub.AdapterControl(generate_pb2.AdapterControlRequest( - lora_ids='all', - operation='remove' - )) - stub.AdapterControl(generate_pb2.AdapterControlRequest( - lora_ids='gsm8k:abcdabcd987/gsm8k-llama2-7b-lora-16,sqlctx:abcdabcd987/sqlctx-llama2-7b-lora-16,viggo:abcdabcd987/viggo-llama2-7b-lora-16', - operation='load' - )) - resp = stub.AdapterControl(generate_pb2.AdapterControlRequest( - operation='status' - )) - print(resp) - # Info print(stub.Info(generate_pb2.InfoRequest())) # Warm up - wr = generate_pb2.WarmupRequest(batch = pb_batch_with_inputs, max_total_tokens = 2048, max_prefill_tokens = 1024*10, max_input_length = 1024) + wr = generate_pb2.WarmupRequest( + batch=pb_batch_with_inputs, + max_total_tokens=2048, + max_prefill_tokens=1024 * 10, + max_input_length=1024, + ) stub.Warmup(wr) # Prefill - pr = generate_pb2.PrefillRequest(batch = pb_batch_empty) + pr = generate_pb2.PrefillRequest(batch=pb_batch_with_inputs) resp = stub.Prefill(pr) - gen, cbatch = resp.generations, resp.batch - # Decode - dr = generate_pb2.DecodeRequest(batches = [cbatch]) - resp = stub.Decode(dr) - gen, cbatch = resp.generations, resp.batch + generations, cbatch = resp.generations, resp.batch + for gen in generations: + print(gen.tokens.texts) + + print("finished prefill tokens") - results = {} - # Generate token - pr = generate_pb2.GenerateTokenRequest(batch = pb_batch_empty) while True: - resp = stub.GenerateToken(pr) + dr = generate_pb2.DecodeRequest(batches=[cbatch]) + resp = stub.Decode(dr) generations, cbatch = resp.generations, resp.batch - if not generations: - break + toExit = False for gen in generations: - if gen.request_id in results: - results[gen.request_id].append(gen.tokens.texts[0]) - else: - results[gen.request_id] = [gen.tokens.texts[0]] - for id in results: - print(str(id) + '=' * 30) - print(''.join(results[id])) - print('done') \ No newline at end of file + if gen.generated_text.text: + print("finished") + res = gen.generated_text.text + toExit = True + + if toExit: + break + + print(res) diff --git a/server/examples/test_ui.py b/server/examples/test_ui.py index 61aa31aa..35fd5d2e 100644 --- a/server/examples/test_ui.py +++ b/server/examples/test_ui.py @@ -1,4 +1,7 @@ -from 
text_generation_server.models.flashinfer_causal_lm import FlashinferLM, FlashinferBatch +from text_generation_server.models.flashinfer_causal_lm import ( + FlashinferLM, + FlashinferBatch, +) from text_generation_server.pb import generate_pb2_grpc, generate_pb2 @@ -25,8 +28,11 @@ "ignore", category=UserWarning, message="TypedStorage is deprecated" ) + class MultiLora: - def __init__(self, base_model, model_type, lora_ids, lora_specs: dict[str, LoraSpec]): + def __init__( + self, base_model, model_type, lora_ids, lora_specs: dict[str, LoraSpec] + ): self.device = torch.device("cuda:0") self.lora_specs = lora_specs self.stop_signal = threading.Event() @@ -67,16 +73,17 @@ def _create_request(self, lora_id: str, lora_or_base: str): repetition_penalty=1.1, ), stopping_parameters=generate_pb2.StoppingCriteriaParameters( - max_new_tokens=2048, - stop_sequences=[], - ignore_eos_token=True), - lora_id=lora_id + max_new_tokens=2048, stop_sequences=[], ignore_eos_token=True + ), + lora_id=lora_id, + ) + batch = generate_pb2.Batch(id=0, requests=[request], size=1) + pb_batch = FlashinferBatch.from_pb( + batch, self.tokenizer, torch.float16, torch.device("cuda") ) - batch = generate_pb2.Batch(id = 0, requests = [request], size = 1) - pb_batch = FlashinferBatch.from_pb(batch, self.tokenizer, torch.float16, torch.device("cuda")) self.model.add_request(pb_batch) - self.id+=1 - return self.id-1, prompt + self.id += 1 + return self.id - 1, prompt def stop(self): self.stop_signal.set() @@ -87,18 +94,20 @@ def run( ): time.sleep(0.1) while not self.stop_signal.is_set(): - generations, _, timing = self.model.generate_token(FlashinferBatch.from_pb(generate_pb2.Batch())) + generations, _, timing = self.model.generate_token( + FlashinferBatch.from_pb(generate_pb2.Batch()) + ) for gen in generations: - append_box('-'.join(self.reqname[gen.request_id]), gen.tokens.texts[0]) + append_box("-".join(self.reqname[gen.request_id]), gen.tokens.texts[0]) for gen in generations: - if gen.generated_text: # finished - append_box('-'.join(self.reqname[gen.request_id]), "\n------\n\n") + if gen.generated_text: # finished + append_box("-".join(self.reqname[gen.request_id]), "\n------\n\n") model_name, lora_or_base = self.reqname[gen.request_id] nid, prompt = self._create_request(model_name, lora_or_base) self.reqname[nid] = (model_name, lora_or_base) - append_box('-'.join(self.reqname[nid]), prompt) - assert(len(self.model.reqctx) == 6) + append_box("-".join(self.reqname[nid]), prompt) + assert len(self.model.reqctx) == 6 class TailLog(Label): @@ -117,6 +126,7 @@ def write(self, append: str): self._last_line_text = last_line.plain.rstrip() self.update(Lines(self._lines + [last_line])) + class MultiLoraTui(App): CSS = """ .box { @@ -137,7 +147,7 @@ class MultiLoraTui(App): class AppendBox(Message): def __init__(self, box_id: str, text: str): super().__init__() - self.box_id = box_id.replace('/', '--') + self.box_id = box_id.replace("/", "--") self.text = text def __init__(self, model_names: list[str]): @@ -148,7 +158,7 @@ def compose(self) -> ComposeResult: yield Header() with Vertical(): for model_name in self._model_names: - model_name = model_name.replace('/', '--') + model_name = model_name.replace("/", "--") with Horizontal(): box_lora = TailLog(id=f"{model_name}-lora", classes="box") box_lora.border_title = f"{model_name}: LoRA finetuned model" @@ -163,7 +173,8 @@ def compose(self) -> ComposeResult: def on_multi_lora_tui_append_box(self, msg: AppendBox): self.query_one(f"#{msg.box_id}").write(msg.text) -if __name__ == 
'__main__': + +if __name__ == "__main__": project_root = pathlib.Path(__file__).parents[1] model_dir = project_root / "model" @@ -174,9 +185,11 @@ def on_multi_lora_tui_append_box(self, msg: AppendBox): # 'abcdabcd987/viggo-llama2-7b-lora-16'] base_model = "tjluyao/llama-3-8b" model_type = "llama" - lora_ids = ['tjluyao/llama-3-8b-math', - 'tjluyao/llama-3-8b-oaast', - 'tjluyao/llama-3-8b-zh'] + lora_ids = [ + "tjluyao/llama-3-8b-math", + "tjluyao/llama-3-8b-oaast", + "tjluyao/llama-3-8b-zh", + ] lora_specs = {} for name, spec in DEMO.items(): @@ -197,4 +210,4 @@ def append_box(box_id, text): thread.start() tui.run() logic.stop() - thread.join() \ No newline at end of file + thread.join() diff --git a/server/exllama_kernels/exllama_kernels/cu_compat.cuh b/server/exllama_kernels/exllama_kernels/cu_compat.cuh new file mode 100644 index 00000000..c5258813 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cu_compat.cuh @@ -0,0 +1,58 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_compat_cuh +#define _cuda_compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif diff --git a/server/exllama_kernels/exllama_kernels/cuda_buffers.cu b/server/exllama_kernels/exllama_kernels/cuda_buffers.cu new file mode 100644 index 00000000..ee2cbee2 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_buffers.cu @@ -0,0 +1,71 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#define _cuda_buffers_cu +#include "cuda_buffers.cuh" + +CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; +// __constant__ half2 q4_table[16][256]; +// half2 q4_table_host[16][256]; +// bool q4_table_init = false; + +CudaBuffers::CudaBuffers +( + int _device, + half* _temp_state, + half* _temp_dq +) : + device(_device), + temp_state(_temp_state), + temp_dq(_temp_dq) +{ + cudaSetDevice(_device); + + cudaStreamCreate(&alt_stream_1); + cudaStreamCreate(&alt_stream_2); + cudaStreamCreate(&alt_stream_3); + cudaEventCreate(&alt_stream_1_done); + cudaEventCreate(&alt_stream_2_done); + cudaEventCreate(&alt_stream_3_done); +} + +CudaBuffers::~CudaBuffers() +{ + 
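// Illustrative sketch (not from the patch above): the atomicAdd overloads in
// cu_compat.cuh backfill half/half2 atomics on GPUs older than CC 7.0 by
// packing the update into an atomicCAS loop on the containing 32-bit word.
// A minimal kernel that relies on them looks like this; on newer GPUs the
// native atomicAdd(half*, half) is picked up instead.
__global__ void atomic_half_sum_example(const half* __restrict__ in, half* __restrict__ out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) atomicAdd(out, in[i]);   // caller zero-initializes *out
}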
cudaStreamDestroy(alt_stream_1); + cudaStreamDestroy(alt_stream_2); + cudaStreamDestroy(alt_stream_3); + cudaEventDestroy(alt_stream_1_done); + cudaEventDestroy(alt_stream_2_done); + cudaEventDestroy(alt_stream_3_done); +} + +CudaBuffers* get_buffers(const int device_index) +{ + return g_buffers[device_index]; +} + +void prepare_buffers_cuda +( + int _device, + half* _temp_state, + half* _temp_dq +) +{ + CudaBuffers* buffers = new CudaBuffers + ( + _device, + _temp_state, + _temp_dq + ); + + g_buffers[_device] = buffers; +} + +void cleanup_buffers_cuda() +{ + for (int i = 0; i < CUDA_MAX_DEVICES; i++) + { + if (!g_buffers[i]) continue; + delete g_buffers[i]; + g_buffers[i] = NULL; + } +} diff --git a/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh b/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh new file mode 100644 index 00000000..afb60a01 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh @@ -0,0 +1,52 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_buffers_cuh +#define _cuda_buffers_cuh + +#include +#include +#include +#include + +const int CUDA_MAX_DEVICES = 16; + +// #ifndef _cuda_buffers_cu +// extern __constant__ half2 q4_table[16][256]; +// #endif + +class CudaBuffers +{ +public: + int device; + + half* temp_state; // [max_hidden_rows * intermediate_size] + half* temp_dq; // size of largest quant tensor * 8 + + cudaStream_t alt_stream_1; + cudaStream_t alt_stream_2; + cudaStream_t alt_stream_3; + cudaEvent_t alt_stream_1_done; + cudaEvent_t alt_stream_2_done; + cudaEvent_t alt_stream_3_done; + + CudaBuffers + ( + int _device, + half* _temp_state, + half* _temp_dq + ); + ~CudaBuffers(); +}; + +CudaBuffers* get_buffers(const int device_index); + +void prepare_buffers_cuda +( + int _device, + half* _temp_state, + half* _temp_dq +); + +void cleanup_buffers_cuda(); + +#endif diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu b/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu new file mode 100644 index 00000000..c25b0206 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu @@ -0,0 +1,61 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "column_remap.cuh" +#include "../util.cuh" + +const int SHUF_BLOCKSIZE_X = 256; +const int SHUF_BLOCKSIZE_Y = 16; + +__global__ void column_remap_kernel +( + const half* __restrict__ x, + half* __restrict__ x_new, + const int x_width, + const int x_height, + const uint32_t* x_map +) +{ + int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; + + int x_stride = x_width; + int x_idx = x_row * x_stride + x_column; + + int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); + int x_idx_end = x_row_end * x_stride + x_column; + + int s_column = x_map[x_column]; + int s_idx = x_row * x_stride + s_column; + + while (x_idx < x_idx_end) + { + x_new[x_idx] = x[s_idx]; + x_idx += x_stride; + s_idx += x_stride; + } +} + +// Remap columns in x to correspond to sequential group index before matmul +// +// perform x -> seq_x such that seq_x @ seq_w == x @ w + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +) +{ + dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); + + dim3 blocks + ( + (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, + (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, + 1 + ); + + column_remap_kernel<<>>(x, x_new, x_width, x_height, 
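// Plain CPU reference for the column remap above (illustrative sketch): for a
// row-major x of shape [height, width], destination column j is filled from
// source column x_map[j], so that seq_x @ seq_w == x @ w once the rows of w
// have been made sequential.
static void column_remap_reference(const half* x, half* x_new,
                                   int height, int width, const uint32_t* x_map)
{
    for (int row = 0; row < height; row++)
        for (int col = 0; col < width; col++)
            x_new[row * width + col] = x[row * width + x_map[col]];
}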
x_map); +} diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh b/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh new file mode 100644 index 00000000..0364e38c --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh @@ -0,0 +1,19 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _column_remap_cuh +#define _column_remap_cuh + +#include +#include +#include + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +); + +#endif diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu new file mode 100644 index 00000000..1b0f7956 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu @@ -0,0 +1,256 @@ +#include "q4_matmul.cuh" +#include "column_remap.cuh" +#include +#include "../util.cuh" +#include "../matrix.cuh" +#include "../cu_compat.cuh" +#include "../cuda_buffers.cuh" +#if defined(USE_ROCM) +#include "../hip_compat.cuh" +#endif + +const int THREADS_X = 32; // Block size and thread count along columns in w and out +const int THREADS_Y = 1; // Block size and thread count along rows in x and out + +typedef void (*fp_q4_matmul_kernel) +( + const half*, + const uint32_t*, + half*, + const half*, + const uint32_t*, + const int, + const int, + const int, + const int, + const int, + const uint32_t*, + bool +); + +template +__global__ void q4_matmul_kernel +( + const half* __restrict__ x, + const uint32_t* __restrict__ w, + half* __restrict__ out, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int dim, + const int width, + const int groupsize, + const int block_size_z, + const uint32_t* __restrict__ x_map, + bool no_zero +) +{ + // Start of block + + int x_column = block_size_z * blockIdx.z; + int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); + + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + int x_row = THREADS_Y * blockIdx.y + threadIdx.y; + + int iterations = (x_column_end - x_column) / 8; + + // Views + + MatrixView_half x_(x, height, dim); + MatrixView_half w_scales_(w_scales, dim / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); + MatrixView_q4_column w_(w, dim, width); + MatrixView_half_rw out_(out, height, width); + + // Zero output + + if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) + { + *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; + __syncthreads(); + } + + // Loop over part of x row (and w column) + + half2 acc = {}; + half acc_h = {}; + + if constexpr (use_groupsize) + { + // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this + // could be slightly faster + + for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) + { + if constexpr (use_half2) + { + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + else + { + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; + + if constexpr (use_x_map) acc_h = 
dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + } + } + else + { + // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache + + for (int k = x_column; k < x_column + iterations * 8; k += 8) + { + if constexpr (use_half2) + { + int group = k / groupsize; + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + else + { + int group = k / groupsize; + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + } + } + + // Add to block result + + if constexpr (use_half2) + { + half result = __hadd(__low2half(acc), __high2half(acc)); + atomicAdd(out_.item_ptr(x_row, w_column), result); + } + else + { + atomicAdd(out_.item_ptr(x_row, w_column), acc_h); + } +} + +fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) +{ + // + if (tuningParams->matmul_no_half2) { + if (block_size_z % groupsize == 0) { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } else { + if (block_size_z % groupsize == 0) + { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } +}; + +// Compute y = x @ w + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero, + cudaStream_t alt_stream +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + + uint32_t* x_map = w->cuda_x_map; + const half* x_mapped = x; + if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) + { + CudaBuffers* buffers = get_buffers(w->device); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + x_map = NULL; + } + + int block_size_z; + if (w->width == 4096) block_size_z = 384; // 7B + else if (w->width == 11008) block_size_z = 256; + else if (w->width == 5120) block_size_z = 384; // 13B + else if (w->width == 13824) block_size_z = 256; + else if (w->width == 6656) block_size_z = 256; // 33B + else if (w->width == 17920) block_size_z = 128; + else block_size_z = 256; + + //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); + + dim3 threads(THREADS_X, THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height + threads.y - 1) / threads.y, + (dim + block_size_z - 1) / block_size_z + ); + + fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); + + kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); +} + +void q4_matmul_recons_cuda +( + ExLlamaTuning* 
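// Scalar reference for the 4-bit dequantization used by the kernel above
// (illustrative sketch): each uint32_t packs eight weights, zero points are
// stored off by one (hence the "+ 1" when they are read), and a weight is
// recovered as (q - zero) * scale for its group.
__host__ __device__ inline float dequant_q4_reference(uint32_t packed, int nibble,
                                                      float scale, uint32_t stored_zero)
{
    uint32_t zero = (stored_zero + 1) & 0x0F;
    int q = (int)((packed >> (nibble * 4)) & 0x0F);
    return (float)(q - (int)zero) * scale;
}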
tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + bool no_zero, + const cublasHandle_t handle +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + CudaBuffers* buffers = get_buffers(w->device); + + const half* x_mapped = x; + if (w->cuda_x_map) + { + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + } + + w->reconstruct(buffers->temp_dq); + + const half alpha = __float2half(1.0f); + const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); + cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); + +// const float alpha = 1.0f; +// const float beta = no_zero ? 1.0f : 0.0f; +// cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, +// x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); +} diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh new file mode 100644 index 00000000..4c7a6669 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh @@ -0,0 +1,37 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matmul_cuh +#define _q4_matmul_cuh + +#include +#include +#include +#include +#include + +#include "q4_matrix.cuh" +#include "../tuning.h" + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero, + cudaStream_t alt_stream +); + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + bool no_zero, + const cublasHandle_t handle +); + +#endif diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu new file mode 100644 index 00000000..1f32e6b8 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu @@ -0,0 +1,220 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include +#include "q4_matrix.cuh" +#include +#include "../util.cuh" +#include "../matrix.cuh" + +using namespace std; + +const int UNSHUF_BLOCKSIZE_X = 64; + +const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column +const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows + +vector g_q4_matrices; + +void g_q4_keep_matrix(Q4Matrix* m) +{ + g_q4_matrices.push_back(m); +} + +void g_q4_free_matrices() +{ + for (const auto& m : g_q4_matrices) delete m; + g_q4_matrices.clear(); +} + +Q4Matrix::Q4Matrix +( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device +) : + height(_height), + width(_width), + groups(_groups), + device(_device) +{ + cudaSetDevice(device); + + cuda_qweight = _qweight; + cuda_qzeros = _qzeros; + cuda_scales = _scales; + + groupsize = height / groups; + + if (_g_idx) make_sequential(_g_idx); +} + +Q4Matrix::~Q4Matrix() +{ +} + +// Make sequential + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const uint32_t* __restrict__ x_map, + const int w_height, + const int w_width +) +{ + const 
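// Note on the two entry points above (illustrative): q4_matmul_cuda keeps the
// weights quantized and accumulates partial results with atomicAdd, which is
// cheapest for small batches, while q4_matmul_recons_cuda dequantizes the whole
// matrix into temp_dq and hands it to cublasHgemm, which pays off as x_height
// grows. The extension chooses between them roughly like this sketch, mirroring
// the check in exllama_ext.cpp further below:
static inline bool use_reconstruction_path(int x_height, int matmul_recons_thd)
{
    // a threshold of 0 disables the cuBLAS reconstruction path entirely
    return matmul_recons_thd != 0 && x_height >= matmul_recons_thd;
}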
uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + + int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int w_new2_row = blockIdx.y; + + int x_map_idx = w_new2_row << 3; + + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = x_map[x_map_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) +{ + uint32_t* cuda_new_qweight = NULL; + cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); + cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch + + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + + // Group histogram + + for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; + + // Group map + + for (int i = 0, acc = 0; i < groups; i++) + { + short tmp = cpu_g_idx_map[i]; + cpu_g_idx_map[i] = acc; + acc += tmp; + } + + // X map (inverse) + + for (int row = 0; row < height; row++) + { + uint32_t target_group = cpu_g_idx[row]; + uint32_t target_row = cpu_g_idx_map[target_group]; + cpu_g_idx_map[target_group]++; + cpu_x_map_inv[row] = target_row; + } + + // X map + + for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; + + // Move to CUDA + + cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); + + // Rearrange rows in w + + dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); + dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); + + // Replace qweights + + cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + + // Cleanup + + cudaDeviceSynchronize(); + cudaFree(cuda_new_qweight); + free(cpu_g_idx_map); + free(cpu_x_map); + free(cpu_x_map_inv); +} + +__global__ void reconstruct_kernel +( + const uint32_t* __restrict__ w, + half* __restrict__ out, // (y) + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int width, + const int groupsize +) +{ + // Start of block + + int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; + int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; + + // Views + + MatrixView_q4_column w_(w, height, width); + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, height / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); + + // Groupsize version + + int group = row / groupsize; + + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F; + + uint32_t w_read = w_.item_uint32_t(row, column); + half* out_ptr = out_.item_ptr(row, column); + + #pragma unroll + for (int s = 0; s < 32; s += 4) + { + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); + *out_ptr = w_item; out_ptr += out_.width; + } +} 
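// Small host-side sanity sketch for make_sequential() above (illustrative):
// cpu_x_map_inv sends an original row to its position in the group-sorted
// order and cpu_x_map is its inverse, so the two compose to the identity.
static bool is_identity_composition(const uint32_t* x_map, const uint32_t* x_map_inv, int height)
{
    for (int row = 0; row < height; row++)
        if (x_map[x_map_inv[row]] != (uint32_t)row) return false;
    return true;
}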
+ +void Q4Matrix::reconstruct(half* out) +{ + dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height / 8 + threads.y - 1) / threads.y, + 1 + ); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); +} diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh new file mode 100644 index 00000000..49431dc9 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh @@ -0,0 +1,53 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matrix_cuh +#define _q4_matrix_cuh + +#include +#include +#include + +class Q4Matrix +{ +public: + + int device; + + int height; + int width; + int groups; + int groupsize; + + uint32_t* cuda_qweight = NULL; + uint32_t* cuda_qzeros = NULL; + half* cuda_scales = NULL; + uint32_t* cuda_x_map = NULL; + + Q4Matrix + ( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device + ); + + ~Q4Matrix(); + + void reconstruct(half* out); + +private: + + void make_sequential(const uint32_t* cpu_g_idx); + +}; + +void g_q4_keep_matrix(Q4Matrix* m); +void g_q4_free_matrices(); + +#endif diff --git a/server/exllama_kernels/exllama_kernels/exllama_ext.cpp b/server/exllama_kernels/exllama_kernels/exllama_ext.cpp new file mode 100644 index 00000000..f2df80e8 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/exllama_ext.cpp @@ -0,0 +1,253 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include +#include +#include +#include +#include +#include +#include +#include "util.cuh" +#include "tuning.h" +#include "cuda_buffers.cuh" +#include "cuda_func/q4_matrix.cuh" +#include "cuda_func/q4_matmul.cuh" +#include "cuda_func/column_remap.cuh" + +// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a +// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of +// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. 
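// Illustrative sketch of the convention described above (hypothetical function
// name): code in the .cu files reports failures by returning a cudaError_t
// rather than throwing, and the extension funnels that value through
// check_cuda() below so a proper Torch error reaches Python.
static cudaError_t example_cuda_entry_point()
{
    // ... launch kernels here ...
    return cudaGetLastError();
}
// Typical call site in a wrapper: check_cuda(example_cuda_entry_point());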
+ +void check_cuda(cudaError_t ret) +{ + switch (ret) + { + case cudaSuccess: + break; + + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; + + default: + printf(" **** CUDA error\n"); \ + printf(" **** %s\n", cudaGetErrorString(ret)); \ + TORCH_CHECK(false, "CUDA error"); \ + break; + } +} + +// Some decluttering macros + +#define STRINGIFY_(__x) #__x +#define STRINGIFY(__x) STRINGIFY_(__x) +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) + +#define TORCH_CHECK_DEVICE_INDEX(__index) \ +do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ + TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ +} while(0) + +#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ +do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ +} while(0) + +int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) +{ + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; +} + + +// Tuning parameters + +ExLlamaTuning tuningParams; + +void set_tuning_params +( + int matmul_recons_thd, + bool matmul_fused_remap, + bool matmul_no_half2 +) +{ + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; +} + + +// Release all unmanaged objects allocated by the extension + +void cleanup() +{ + cleanup_buffers_cuda(); + g_q4_free_matrices(); +} + + +// Prepare buffers for forward pass + +void prepare_buffers +( + torch::Device device, + torch::Tensor temp_state, + torch::Tensor temp_dq +) +{ + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); + + prepare_buffers_cuda + ( + device_index, + (half*) temp_state.data_ptr(), + (half*) temp_dq.data_ptr() + ); +} + + +// Create Q4Matrix, return handle + +uintptr_t make_q4 +( + torch::Tensor qweight, + torch::Tensor qzeros, + torch::Tensor scales, + torch::Tensor g_idx, + int device +) +{ + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); + + int width = qweight.size(1); + 
int height = qweight.size(0) * 8; + int groups = qzeros.size(0); + + Q4Matrix* m = new Q4Matrix + ( + height, + width, + groups, + + (uint32_t*) qweight.data_ptr(), + (uint32_t*) qzeros.data_ptr(), + (half*) scales.data_ptr(), + g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), + + device + ); + + g_q4_keep_matrix(m); + return reinterpret_cast (m); +} + + +// Matmul half @ quant -> half + +void q4_matmul +( + torch::Tensor x, + uintptr_t w, + torch::Tensor out +) +{ + Q4Matrix* wm = reinterpret_cast (w); + + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + int x_height = x.size(0); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) + { + q4_matmul_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + false, + stream + ); + } + else + { + q4_matmul_recons_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + false, + at::cuda::getCurrentCUDABlasHandle() + ); + } +} + + +// Remap columns in half tensor + +void column_remap +( + torch::Tensor x, + torch::Tensor x_new, + torch::Tensor x_map +) +{ + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); + + int height = x.size(0); + int width = x.size(1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + column_remap_cuda + ( + (half*) x.data_ptr(), + (half*) x_new.data_ptr(), + height, + width, + (uint32_t*) x_map.data_ptr() + ); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); + m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); + m.def("cleanup", &cleanup, "cleanup"); + m.def("make_q4", &make_q4, "make_q4"); + m.def("q4_matmul", &q4_matmul, "q4_matmul"); +} diff --git a/server/exllama_kernels/exllama_kernels/hip_compat.cuh b/server/exllama_kernels/exllama_kernels/hip_compat.cuh new file mode 100644 index 00000000..f2a3dcad --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/hip_compat.cuh @@ -0,0 +1,52 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _hip_compat_cuh +#define _hip_compat_cuh + +// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6. +__device__ __forceinline__ __half __compat_hrcp(__half x) { + return __half_raw{ + static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; +} + +__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { + return _Float16_2{ + _Float16_2{static_cast<_Float16>(1.0f), + static_cast<_Float16>(1.0f)} / x.data}; +} + +#define hrcp __compat_hrcp +#define h2rcp __compat_h2rcp + +// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf. 
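// Minimal illustrative use of the exllama_ext entry points above (sketch only;
// it assumes the tensors already satisfy the TORCH_CHECK_* guards: qweight and
// qzeros int32, scales/x/out fp16, all on the target CUDA device).
void q4_forward_example(torch::Tensor qweight, torch::Tensor qzeros,
                        torch::Tensor scales, torch::Tensor g_idx,
                        torch::Tensor x, torch::Tensor out, int device)
{
    uintptr_t w = make_q4(qweight, qzeros, scales, g_idx, device);
    q4_matmul(x, w, out);   // out = x @ dequant(w); x_height decides CUDA vs cuBLAS path
}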
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, + hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + const half* alpha, + const half* AP, + int lda, + const half* BP, + int ldb, + const half* beta, + half* CP, + int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} +#define hipblasHgemm __compat_hipblasHgemm + +// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. +#define rocblas_handle hipblasHandle_t +#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_get_stream hipblasGetStream +#define rocblas_set_stream hipblasSetStream +#define rocblas_hgemm __compat_hipblasHgemm + +#endif diff --git a/server/exllama_kernels/exllama_kernels/matrix.cuh b/server/exllama_kernels/exllama_kernels/matrix.cuh new file mode 100644 index 00000000..2fd5ab0b --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/matrix.cuh @@ -0,0 +1,294 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _matrix_cuh +#define _matrix_cuh + +#include +#include + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } +}; + +class MatrixView_q4_column +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, 
int column) const + { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +}; + +// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale + +__device__ __forceinline__ half2 dot_product_8 +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + +// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) +// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; +// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; +// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; + + half2 tmp = __hmul2(*h_ptr++, v_01); + tmp = __hfma2(*h_ptr++, v_23, tmp); + tmp = __hfma2(*h_ptr++, v_45, tmp); + tmp = __hfma2(*h_ptr++, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) 
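// Plain-float reference for dot_product_8 above (illustrative sketch): it
// accumulates count * 8 products of activations with dequantized 4-bit weights,
// sharing one (already +1-adjusted) zero point and one scale per call. For
// clarity the nibbles are assumed to be pre-unpacked into bytes here.
static inline float dot_product_8_reference(float acc, const float* h, const uint8_t* q,
                                            float scale, int zero, int count)
{
    for (int i = 0; i < count * 8; i++)
        acc += scale * h[i] * (float)((int)q[i] - zero);
    return acc;
}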
+ const int count +) +{ + const half* h_ptr = h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(*h_ptr++, v_0); + tmp = __hfma(*h_ptr++, v_1, tmp); + tmp = __hfma(*h_ptr++, v_2, tmp); + tmp = __hfma(*h_ptr++, v_3, tmp); + tmp = __hfma(*h_ptr++, v_4, tmp); + tmp = __hfma(*h_ptr++, v_5, tmp); + tmp = __hfma(*h_ptr++, v_6, tmp); + tmp = __hfma(*h_ptr++, v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map + +__device__ __forceinline__ half2 dot_product_8_x_map +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + + half h_0 = h_ptr[*x_map_ptr++]; + half h_1 = h_ptr[*x_map_ptr++]; + half h_2 = h_ptr[*x_map_ptr++]; + half h_3 = h_ptr[*x_map_ptr++]; + half h_4 = h_ptr[*x_map_ptr++]; + half h_5 = h_ptr[*x_map_ptr++]; + half h_6 = h_ptr[*x_map_ptr++]; + half h_7 = h_ptr[*x_map_ptr++]; + + half2 h_01 = __halves2half2(h_0, h_1); + half2 h_23 = __halves2half2(h_2, h_3); + half2 h_45 = __halves2half2(h_4, h_5); + half2 h_67 = __halves2half2(h_6, h_7); + + half2 tmp = __hmul2(h_01, v_01); + tmp = __hfma2(h_23, v_23, tmp); + tmp = __hfma2(h_45, v_45, tmp); + tmp = __hfma2(h_67, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_x_map_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 
1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); + tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +#endif diff --git a/server/exllama_kernels/exllama_kernels/tuning.h b/server/exllama_kernels/exllama_kernels/tuning.h new file mode 100644 index 00000000..770ca46a --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/tuning.h @@ -0,0 +1,13 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _tuning_h +#define _tuning_h + +struct ExLlamaTuning +{ + int matmul_recons_thd; + bool matmul_fused_remap; + bool matmul_no_half2; +}; + +#endif diff --git a/server/exllama_kernels/exllama_kernels/util.cuh b/server/exllama_kernels/exllama_kernels/util.cuh new file mode 100644 index 00000000..7b397573 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/util.cuh @@ -0,0 +1,33 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include + +#if defined(USE_ROCM) +#define cudaUnspecified hipErrorUnknown +#else +#define cudaUnspecified cudaErrorApiFailureBase +#endif + +// React to failure on return code != cudaSuccess + +#define _cuda_check(fn) \ +do { \ + {_cuda_err = fn;} \ + if (_cuda_err != cudaSuccess) goto _cuda_fail; \ +} while(false) + +// React to failure on return code == 0 + +#define _alloc_check(fn) \ +do { \ + if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ + else _cuda_err = cudaSuccess; \ +} while(false) + +#endif diff --git a/server/exllama_kernels/setup.py b/server/exllama_kernels/setup.py new file mode 100644 index 00000000..987d181e --- /dev/null +++ b/server/exllama_kernels/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name="exllama_kernels", + ext_modules=[ + CUDAExtension( + name="exllama_kernels", + sources=[ + "exllama_kernels/exllama_ext.cpp", + "exllama_kernels/cuda_buffers.cu", + "exllama_kernels/cuda_func/column_remap.cu", + "exllama_kernels/cuda_func/q4_matmul.cu", + "exllama_kernels/cuda_func/q4_matrix.cu", + ], + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/server/exllamav2_kernels/exllamav2_kernels/config.h b/server/exllamav2_kernels/exllamav2_kernels/config.h new file mode 100644 index 00000000..32a1a37d --- /dev/null +++ 
b/server/exllamav2_kernels/exllamav2_kernels/config.h @@ -0,0 +1,15 @@ +#ifndef _config_h +#define _config_h + +#define MAX_Q_GEMM_ROWS 50 +#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS + +#define QMODE_2BIT 1 +#define QMODE_3BIT 1 +#define QMODE_4BIT 1 +#define QMODE_5BIT 1 +#define QMODE_6BIT 0 +#define QMODE_8BIT 0 + + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cpp/util.h b/server/exllamav2_kernels/exllamav2_kernels/cpp/util.h new file mode 100644 index 00000000..919703a8 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cpp/util.h @@ -0,0 +1,12 @@ +#ifndef _util_h +#define _util_h + +#define DBGS(__x) printf("%s\n", __x) +#define DBGI(__x) printf("%s: %i\n", #__x, __x) +#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y) +#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGF(__x) printf("%s: %f\n", #__x, __x) +#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y) +#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh new file mode 100644 index 00000000..12684ff8 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh @@ -0,0 +1,56 @@ +#ifndef _compat_cuh +#define _compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? 
(old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh new file mode 100644 index 00000000..a72bc7bc --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh @@ -0,0 +1,121 @@ +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "quant/qdq_util.cuh" + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } + + __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const + { + half2* ptr = (half2*) item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int 
column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) + { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*) item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu new file mode 100644 index 00000000..5b99f1ba --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu @@ -0,0 +1,220 @@ +#include "q_gemm.cuh" +#include "util.cuh" +#include "matrix_view.cuh" +#include "../config.h" + +#include "quant/qdq_2.cuh" +#include "quant/qdq_3.cuh" +#include "quant/qdq_4.cuh" +#include "quant/qdq_5.cuh" +#include "quant/qdq_6.cuh" +#include "quant/qdq_8.cuh" + +#define GPTQ_BLOCK_KN_SIZE 128 +#define GPTQ_BLOCK_M_SIZE_MAX 8 +#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32) + +#define EXL2_BLOCK_KN_SIZE 64 +#define EXL2_BLOCK_M_SIZE_MAX 8 +#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32) + +#define CLEAR_N_SIZE 256 + +#include "q_gemm_kernel.cuh" +#include "q_gemm_kernel_gptq.cuh" + +void gemm_half_q_half_cuda_part +( + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + int m_count, + bool clear, + const half* r_weights, + int r_weights_stride, + bool mul_r_weights +) +{ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (!b->is_gptq) + { + dim3 blockDim, gridDim; + blockDim.x = EXL2_BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE); + + fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights); + + kernel<<>> + ( + a, + b->cuda_q_weight, + b->cuda_q_scale, + b->cuda_q_scale_max, + c, + size_m, + size_n, + size_k, + b->groups, + b->cuda_q_group_map, + b->cuda_q_perm, + b->rows_8, + b->rows_6, + b->rows_5, + b->rows_4, + b->rows_3, + b->rows_2, + clear, + r_weights, + r_weights_stride + ); + } + else + { + dim3 
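// The grid sizing above uses DIVIDE as a rounding-up division (assumption; the
// macro itself lives in util.cuh, not shown here): each block covers
// EXL2_BLOCK_KN_SIZE * 4 output columns, m_count output rows, and one
// EXL2_BLOCK_KN_SIZE slice of the reduction dimension. A plain equivalent:
static inline int ceil_div_example(int a, int b) { return (a + b - 1) / b; }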
blockDim, gridDim; + blockDim.x = GPTQ_BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights); + +// DBGX((uint64_t) r_weights); +// if (r_weights) +// print_global_mem(r_weights, 1, 1, 1); +// DBGI(r_weights_stride); + + kernel<<>> + ( + a, + b->cuda_q_weight, + b->cuda_gptq_qzeros, + b->cuda_gptq_scales, + c, + size_m, + size_n, + size_k, + b->groups, + b->gptq_groupsize, + b->cuda_q_perm, + b->rows_4, + clear, + r_weights, + r_weights_stride + ); + } +} + +void gemm_half_q_half_cuda +( + cublasHandle_t cublas_handle, + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + bool clear, + half* temp_dq, + bool force_cuda, + const half* r_weights, + const int r_weights_stride, + bool mul_r_weights +) +{ + if (size_m > MAX_Q_GEMM_ROWS && !force_cuda) + { + // Reconstruct FP16 matrix, then cuBLAS + + if (!temp_dq) temp_dq = b->temp_dq; + b->reconstruct(temp_dq); + + //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH); + + const half alpha = __float2half(1.0f); + const half beta = clear ? __float2half(0.0f) : __float2half(1.0f); + cublasHgemm(cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + size_n, size_m, size_k, + &alpha, temp_dq, size_n, + a, size_k, + &beta, c, size_n); + + //const float alpha = 1.0f; + //const float beta = clear ? 0.0f : 1.0f; + //cublasSgemmEx(cublas_handle, + // CUBLAS_OP_N, + // CUBLAS_OP_N, + // size_n, size_m, size_k, + // &alpha, temp_dq, CUDA_R_16F, size_n, + // a, CUDA_R_16F, size_k, + // &beta, c, CUDA_R_16F, size_n); + + //const float alpha = 1.0f; + //const float beta = clear ? 0.0f : 1.0f; + //cublasGemmEx(cublas_handle, + // CUBLAS_OP_N, CUBLAS_OP_N, + // size_n, size_m, size_k, + // &alpha, temp_dq, CUDA_R_16F, size_n, + // a, CUDA_R_16F, size_k, + // &beta, c, CUDA_R_16F, size_n, + // CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP); + } + else + { + // Quantized matmul + + int block_m_size_max = b->is_gptq ? 
GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX; + int max_chunks = size_m / block_m_size_max; + int last_chunk = max_chunks * block_m_size_max; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) + { + gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights); + } + + if (last_chunk_size) + { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights); + } + } +} + +__global__ void clear_kernel +( + half* __restrict__ c, + const int size_m, + const int size_n +) +{ + int m = blockIdx.y; + int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8; + if (n >= size_n) return; + int4* c_ptr = (int4*)(c + m * size_n + n); + *c_ptr = {}; +} + +void clear_tensor_cuda +( + half* c, + int size_m, + int size_n +) +{ +// dim3 blockDim, gridDim; +// blockDim.x = CLEAR_N_SIZE; +// blockDim.y = 1; +// gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE); +// gridDim.y = size_m; +// clear_kernel<<>>(c, size_m, size_n); +} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh new file mode 100644 index 00000000..e49457f3 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh @@ -0,0 +1,36 @@ +#ifndef _q_gemm_cuh +#define _q_gemm_cuh + +#include +#include +#include +#include +#include + +#include "q_matrix.cuh" + +void gemm_half_q_half_cuda +( + cublasHandle_t cublas_handle, + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + bool clear = false, + half* reconstruct = NULL, + bool force_cuda = false, + const half* r_weights = NULL, + const int r_weights_stride = 0, + bool mul_r_weights = false +); + +void clear_tensor_cuda +( + half* c, + int size_m, + int size_n +); + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh new file mode 100644 index 00000000..9cd2ba01 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh @@ -0,0 +1,580 @@ +#include "compat.cuh" + +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, 
result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h) +{ + // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127 + + float result = {}; + #pragma unroll + for (int i = 0; i < 4; i++) + { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); +} + +__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + + +typedef void (*fp_gemm_half_q_half_kernel) +( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const uint16_t*, + const uint16_t*, + const int, + const int, + const int, + const int, + const int, + const int, + const bool, + const half*, + const int +); + +template +__global__ void gemm_half_q_half_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_q_scale, + const half* __restrict__ b_q_scale_max, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const uint16_t* __restrict__ b_q_group_map, + const uint16_t* __restrict__ b_q_perm, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2, + const bool clear, + const half* r_weights, + const int r_weights_stride +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); + + int t = threadIdx.x; + + // Block + + int 
offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE; + + int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k); + int n = offset_n + t * 4; + + // Read weights + + half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; + if constexpr (use_r_weights) + { + uint16_t any_w = 0; + const half* w_ptr = r_weights; + for (int m = 0; m < m_count; ++m) + { + weights[m].as_half = *w_ptr; + w_ptr += r_weights_stride; + any_w |= weights[m].as_uint16; + } + if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!) + } + + // Preload block_a + + __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + half a0 = a_ptr[b_q_perm[offset_k + t]]; +// half a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Clear + + if (n >= size_n) return; + + if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + + //int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; + +// if (offset_m == 0 && t == 0) +// DBGI2(offset_k, group); + + // Preload scales + + half scales[EXL2_MAX_GROUPS_IN_BLOCK][4]; + + //int groups_in_block = DIVIDE((end_k - offset_k), groupsize); + int temp_k = offset_k; + for (int g = 0; temp_k < end_k; g++) + { + int qscales[4]; + b_q_scale_.item4(qscales, group + g, n); + qscales[0]++; + qscales[1]++; + qscales[2]++; + qscales[3]++; + half maxscale = b_q_scale_max[group + g]; + scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale); + scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale); + scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale); + scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale); + temp_k += b_q_group_map[temp_k * 2 + 1]; + } + + // a, b offset + + int pre_rows_8 = min(rows_8, offset_k); + int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; + int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; + int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; + int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; + int pre_rows_2 = offset_k > rows_3 ? 
min(rows_2, offset_k) - rows_3 : 0; + int qk = 0; + qk += pre_rows_8 / 32 * 8; + qk += pre_rows_6 / 32 * 6; + qk += pre_rows_5 / 32 * 5; + qk += pre_rows_4 / 32 * 4; + qk += pre_rows_3 / 32 * 3; + qk += pre_rows_2 / 32 * 2; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = EXL2_BLOCK_KN_SIZE; + + // Initial group + + int scales_idx = 0; + half qs_h0 = scales[scales_idx][0]; + half qs_h1 = scales[scales_idx][1]; + half qs_h2 = scales[scales_idx][2]; + half qs_h3 = scales[scales_idx][3]; + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; + + // Column result + + half block_c[m_count][4] = {}; + + // Dequantize groups + + int k = offset_k; + + while (k < rows_8 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + int4 load_int4[2]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 8; + } + k += 32; + } + + while (k < rows_6 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 2; j++) + { + int4 load_int4[3]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][8]; + dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n); + dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n); + dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n); + dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 16; + } + k += 32; + } + + while (k < rows_5 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 
1; j++) + { + int4 load_int4[5]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[3] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[4] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][16]; + dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n); + dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n); + dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n); + dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 32; + } + + k += 32; + } + + while (k < rows_4 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + int4 load_int4[1]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][4]; + dequant_4bit_8(load_int4[0].x, dq[0], size_n); + dequant_4bit_8(load_int4[0].y, dq[1], size_n); + dequant_4bit_8(load_int4[0].z, dq[2], size_n); + dequant_4bit_8(load_int4[0].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 8; + } + k += 32; + } + + while (k < rows_3 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[3]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = 
dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 32; + } + k += 32; + } + + while (k < rows_2 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[1]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][8]; + dequant_2bit_16(load_int4[0].x, dq[0], size_n); + dequant_2bit_16(load_int4[0].y, dq[1], size_n); + dequant_2bit_16(load_int4[0].z, dq[2], size_n); + dequant_2bit_16(load_int4[0].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + + a_ptr += 16; + } + k += 16; + } + + // Accumulate column sums in c + + for (int m = 0; m < m_count; m++) + { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + + if constexpr (mul_r_weights) + { + half2 w_mul2 = __half2half2(weights[m].as_half); + result01 = __hmul2(result01, w_mul2); + result23 = __hmul2(result23, w_mul2); + } + + atomicAdd(out , result01); + atomicAdd(out + 1, result23); +// *out = result01; +// *(out + 1) = result23; + } +} + +template +struct map_m_count_exl2 { + static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count) + { + #if EXL2_BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>; + #endif + return NULL; + } +}; + +fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights) +{ + if (!r_weights && !mul_r_weights) return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); + if (!r_weights && mul_r_weights) return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); + if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count); + if ( r_weights && mul_r_weights) return map_m_count_exl2< true, true>::pick_gemm_half_q_half_kernel(m_count); + return NULL; +} diff --git 
a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh new file mode 100644 index 00000000..f816fd9d --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh @@ -0,0 +1,273 @@ +#include "compat.cuh" + +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hadd2(result, g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __half2float(__low2half(result)) + __half2float(__high2half(result)); +} + +__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return result; +} + +typedef void (*fp_gemm_half_q_half_gptq_kernel) +( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const int, + const uint16_t*, + const int, + const bool, + const half*, + const int +); + +template +__global__ void gemm_half_q_half_gptq_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int groupsize, + const uint16_t* __restrict__ b_q_perm, + const int rows_4, + const bool clear, + const half* r_weights, + const int r_weights_stride +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + + int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE; + + int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Read weights + + half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; + if constexpr (use_r_weights) + { + uint16_t any_w = 0; + const half* w_ptr = r_weights; + for (int m = 0; m < m_count; ++m) + { + weights[m].as_half = *w_ptr; + w_ptr += r_weights_stride; + any_w |= weights[m].as_uint16; + } + if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!) 
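+        // Note: when every per-row weight for this block is zero the kernel returns here,
+        // before the clear step below, so the corresponding rows of c are left untouched
+        // rather than zeroed -- the caller has to account for that.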
+ } + + // Preload block_a + + __shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; + else a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + + if (n >= size_n) return; + + if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = GPTQ_BLOCK_KN_SIZE; + + // Initial group + + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); + +// __syncthreads(); + + // Column result + + half2 block_c[m_count][4] = {}; + + // Dequantize and multiply + + int k = offset_k; + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + #pragma unroll + for (int m = 0; m < m_count; m++) + { + if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; } + block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); + block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); + block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); + block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) + { + half2 *out = (half2*) c_.item_ptr(offset_m + m, n); + half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0])); + half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1])); + half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2])); + half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3])); + half2 result01 = __halves2half2(result0, result1); + half2 result23 = __halves2half2(result2, result3); + + if constexpr (mul_r_weights) + { + half2 w_mul2 = 
__half2half2(weights[m].as_half); + result01 = __hmul2(result01, w_mul2); + result23 = __hmul2(result23, w_mul2); + } + + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + +template +struct map_m_count_gptq { + static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count) + { + #if GPTQ_BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>; + #endif + #if GPTQ_BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>; + #endif + return NULL; + } +}; + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights) +{ + if (!r_weights && !mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); + if (!r_weights && mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); + if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count); + if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count); + return NULL; +} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu new file mode 100644 index 00000000..f7a91e29 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu @@ -0,0 +1,650 @@ +#include "q_matrix.cuh" +#include "matrix_view.cuh" +#include "util.cuh" + +#include "quant/qdq_2.cuh" +#include "quant/qdq_3.cuh" +#include "quant/qdq_4.cuh" +#include "quant/qdq_5.cuh" +#include "quant/qdq_6.cuh" +#include "quant/qdq_8.cuh" + +#define BLOCK_KN_SIZE 128 + +#define THREADS_X 32 +#define THREADS_Y 32 + +// Shuffle quantized data on load + +__global__ void shuffle_kernel +( + uint32_t* __restrict__ b_q_weight, + const int size_k, + const int size_n, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2 +) +{ + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; } + while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; } + while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; } + while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } + while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; } + while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; } +} + + +// QMatrix constructor + 
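+// The constructor below captures the device pointers handed in by the caller and derives the
+// bookkeeping it needs from them: for GPTQ tensors it infers the group size as the smallest
+// power of two with gptq_groupsize * groups >= height (e.g. groups = 32, height = 4096 gives
+// gptq_groupsize = 128) and, if a g_idx permutation is supplied, reorders rows via
+// make_sequential(); for EXL2 tensors it turns the per-group bit widths in q_groups into the
+// cumulative row boundaries rows_8..rows_2. Finally it launches shuffle_kernel to repack the
+// quantized words into the layout the dequant routines expect.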
+QMatrix::QMatrix +( + const int _device, + const int _height, + const int _width, + const int _groups, + + uint32_t* _q_weight, + uint16_t* _q_perm, + uint16_t* _q_invperm, + uint32_t* _q_scale, + half* _q_scale_max, + uint16_t* _q_groups, + uint16_t* _q_group_map, + + uint32_t* _gptq_qzeros, + half* _gptq_scales, + uint32_t* _gptq_g_idx, + + half* _temp_dq +) : + device(_device), + height(_height), + width(_width), + groups(_groups), + temp_dq(_temp_dq) +{ + cudaSetDevice(device); + + failed = false; + + cuda_q_weight = _q_weight; + cuda_q_perm = _q_perm; + cuda_q_invperm = _q_invperm; + cuda_q_scale = _q_scale; + cuda_q_scale_max = _q_scale_max; + cuda_q_groups = _q_groups; + cuda_q_group_map = _q_group_map; + cuda_gptq_qzeros = _gptq_qzeros; + cuda_gptq_scales = _gptq_scales; + + is_gptq = (_gptq_qzeros != NULL); + + if (is_gptq) + { + gptq_groupsize = 1; + while (gptq_groupsize * groups < height) gptq_groupsize *= 2; + } + + // Create group map + + rows_8 = 0; + rows_6 = 0; + rows_5 = 0; + rows_4 = 0; + rows_3 = 0; + rows_2 = 0; + + if (!is_gptq) + { + uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t)); + cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost); + + int row = 0; + for (int i = 0; i < groups; i++) + { + int bits = cpu_q_groups[i * 2]; + + int rows; + if (i < groups - 1) + { + int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1]; + rows = qrows * 32 / bits; + } + else rows = height - row; + + if (bits == 8) rows_8 += rows; + if (bits == 6) rows_6 += rows; + if (bits == 5) rows_5 += rows; + if (bits == 4) rows_4 += rows; + if (bits == 3) rows_3 += rows; + if (bits == 2) rows_2 += rows; + row += rows; + } + + free(cpu_q_groups); + + rows_6 += rows_8; + rows_5 += rows_6; + rows_4 += rows_5; + rows_3 += rows_4; + rows_2 += rows_3; + } + else + { + rows_4 = height; + rows_3 = height; + rows_2 = height; + + if (_gptq_g_idx) + { + if (!make_sequential(_gptq_g_idx)) + { + failed = true; + //printf("FAIL\n"); + return; + } + } + } + +// DBGI(rows_8); +// DBGI(rows_6); +// DBGI(rows_5); +// DBGI(rows_4); +// DBGI(rows_3); +// DBGI(rows_2); + + // Shuffle quantized data + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + shuffle_kernel<<>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2); +} + +QMatrix::~QMatrix() +{ +} + +// Reconstruct b[k,n] (GPTQ) + +__global__ void reconstruct_gptq_kernel +( + const uint32_t* __restrict__ b_q_weight, + const uint16_t* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + //const uint16_t* __restrict__ b_q_groups, + const int size_k, + const int size_n, + const int groupsize, + const int groups, + half* __restrict__ b, + const int rows_4 +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + + __shared__ uint16_t perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) + { + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + } + + // Column + + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + 
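+    // GPTQ uses a single uniform group size, so the group holding row offset_k is simply
+    // offset_k / groupsize; the EXL2 path instead consults b_q_group_map because its groups
+    // can differ in size and bit width.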
+ int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); + } + + for (int p = 0; p < 4; p++) + { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + b_ptr += size_n; + //half* dqh = (half*)dq; + if (b_q_perm) + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + else + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + + +// Reconstruct b[k,n] + +__global__ void reconstruct_kernel +( + const uint32_t* __restrict__ b_q_weight, + const uint16_t* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_q_scale, + const half* __restrict__ b_q_scale_max, + const uint16_t* __restrict__ b_q_group_map, + const int size_k, + const int size_n, + //const int groupsize, + const int groups, + half* __restrict__ b, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2 +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x; + + // Preload remapping table + + int t = threadIdx.x; + __shared__ uint16_t perm[BLOCK_KN_SIZE]; + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + + // Column + + int n = offset_n + t; + if (n >= size_n) return; + + // Find initial group + + // int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; + + int pre_rows_8 = min(rows_8, offset_k); + int pre_rows_6 = offset_k > rows_8 ? 
min(rows_6, offset_k) - rows_8 : 0; + int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; + int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; + int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; + int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0; + int qk = 0; + qk += pre_rows_8 / 32 * 8; + qk += pre_rows_6 / 32 * 6; + qk += pre_rows_5 / 32 * 5; + qk += pre_rows_4 / 32 * 4; + qk += pre_rows_3 / 32 * 3; + qk += pre_rows_2 / 32 * 2; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + half2 qs_h2 = __halves2half2(qs_h, qs_h); + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + int k = offset_k; + int lk = 0; + + __syncthreads(); + + while (k < rows_8 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 4; p++) + { + half2 dq[4]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + uint32_t q_1 = *b_ptr; b_ptr += size_n; + dequant_8bit_8(q_0, q_1, dq, size_n); + for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_6 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 2; p++) + { + half2 dq[8]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + uint32_t q_1 = *b_ptr; b_ptr += size_n; + uint32_t q_2 = *b_ptr; b_ptr += size_n; + dequant_6bit_16(q_0, q_1, q_2, dq, size_n); + for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_5 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 1; p++) + { + half2 dq[16]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + uint32_t q_1 = *b_ptr; b_ptr += size_n; + uint32_t q_2 = *b_ptr; b_ptr += size_n; + uint32_t q_3 = *b_ptr; b_ptr += size_n; + uint32_t q_4 = *b_ptr; b_ptr += size_n; + dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n); + for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_4 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 4; p++) + { + half2 dq[4]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + dequant_4bit_8(q_0, dq, size_n); + for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_3 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 1; p++) + { + half2 dq[16]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + 
uint32_t q_1 = *b_ptr; b_ptr += size_n; + uint32_t q_2 = *b_ptr; b_ptr += size_n; + dequant_3bit_32(q_0, q_1, q_2, dq, size_n); + for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_2 && k < end_k) + { + if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); } + for (int p = 0; p < 1; p++) + { + half2 dq[8]; + uint32_t q_0 = *b_ptr; b_ptr += size_n; + dequant_2bit_16(q_0, dq, size_n); + for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*) dq; + for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 16; + } +} + +void QMatrix::reconstruct(half* out) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (!is_gptq) + { + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + reconstruct_kernel<<>> + ( + cuda_q_weight, + cuda_q_perm, + cuda_q_scale, + cuda_q_scale_max, + cuda_q_group_map, + height, + width, + //groupsize, + groups, + out, + rows_8, + rows_6, + rows_5, + rows_4, + rows_3, + rows_2 + ); + } + else + { + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4); + reconstruct_gptq_kernel<<>> + ( + cuda_q_weight, + cuda_q_perm, + cuda_gptq_qzeros, + cuda_gptq_scales, + //const uint16_t* __restrict__ b_q_groups, + height, + width, + gptq_groupsize, + groups, + out, + rows_4 + ); + } +} + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const uint16_t* __restrict__ q_perm, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + + int w_new2_row = blockIdx.y; + + int q_perm_idx = w_new2_row << 3; + + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +bool QMatrix::make_sequential(const uint32_t* cpu_g_idx) +{ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + uint32_t* cuda_new_qweight = NULL; + cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); + if (err != cudaSuccess) { + cudaError_t cuda_status = cudaGetLastError(); // Clear error + return false; + } + + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + + // Group histogram + + for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; + + // Group map + + for (int i = 0, acc = 0; i < groups; i++) + { + short tmp = cpu_g_idx_map[i]; + cpu_g_idx_map[i] = acc; + acc += tmp; + } + + // X map (inverse) + + for (int row = 0; row < height; row++) + { + uint32_t target_group = cpu_g_idx[row]; + uint32_t target_row = cpu_g_idx_map[target_group]; + 
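+        // cpu_g_idx_map[g] now holds the next free destination row for group g (the counts
+        // were converted to running start offsets above), so rows are packed group by group
+        // in their original order -- effectively a stable counting sort over g_idx.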
cpu_g_idx_map[target_group]++; + cpu_x_map_inv[row] = target_row; + } + + // X map + + for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; + + // Reduce to uint16_t + + uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map; + uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv; + for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row]; + for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row]; + + // Move to CUDA + + cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice); + cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice); + + // Rearrange rows in w + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = height / 8; + + make_sequential_kernel<<>> + ( + cuda_q_weight, + cuda_new_qweight, + cuda_q_perm, + height / 8, + width + ); + + // Replace qweights + + cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + + // Cleanup + + cudaDeviceSynchronize(); + + cudaFree(cuda_new_qweight); + free(cpu_g_idx_map); + free(cpu_x_map); + free(cpu_x_map_inv); + + return true; +} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh new file mode 100644 index 00000000..d36b8d66 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh @@ -0,0 +1,75 @@ +#ifndef _q_matrix_cuh +#define _q_matrix_cuh + +#include +#include +#include +#include + +#define MAX_SUPERGROUPS 16 + +class QMatrix +{ +public: + + int device; + bool is_gptq; + + int height; + int width; + int groups; + int gptq_groupsize; + + int rows_8; + int rows_6; + int rows_5; + int rows_4; + int rows_3; + int rows_2; + + uint32_t* cuda_q_weight = NULL; + uint16_t* cuda_q_perm = NULL; + uint16_t* cuda_q_invperm = NULL; + uint32_t* cuda_q_scale = NULL; + half* cuda_q_scale_max = NULL; + uint16_t* cuda_q_groups = NULL; + uint16_t* cuda_q_group_map = NULL; + uint32_t* cuda_gptq_qzeros = NULL; + half* cuda_gptq_scales = NULL; + + half* temp_dq; + + bool failed; + + QMatrix + ( + const int _device, + const int _height, + const int _width, + const int _groups, + + uint32_t* _q_weight, + uint16_t* _q_perm, + uint16_t* _q_invperm, + uint32_t* _q_scale, + half* _q_scale_max, + uint16_t* _q_groups, + uint16_t* _q_group_map, + + uint32_t* _gptq_qzeros, + half* _gptq_scales, + uint32_t* _gptq_g_idx, + + half* _temp_dq + ); + + ~QMatrix(); + + void reconstruct(half* out); + bool make_sequential(const uint32_t* cpu_g_idx); + +private: + +}; + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh new file mode 100644 index 00000000..90c18a0c --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh @@ -0,0 +1,103 @@ +#ifndef _qdq_2_cuh +#define _qdq_2_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_2BIT == 1 + +// Permutation: +// +// ffddbb99 77553311 eeccaa88 66442200 + +__forceinline__ __device__ void shuffle_2bit_16 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; +} + +__forceinline__ 
__device__ void dequant_2bit_16 +( + const uint32_t q_0, + half2 (&dq)[8], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); + const half z1_ = __float2half_rn(-1024.0f - 2.0f); + const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f); + const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z4 = __halves2half2(z4_, z4_); + const half2 z16 = __halves2half2(z16_, z16_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); +} + +#else + +__forceinline__ __device__ void shuffle_2bit_16 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_2bit_16 +( + const uint32_t q_0, + half2 (&dq)[8], + int stride +) +{ + half dqh[16]; + for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2); + + for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh new file mode 100644 index 00000000..10117376 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh @@ -0,0 +1,169 @@ +#ifndef _qdq_3_cuh +#define _qdq_3_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_3BIT == 1 + +// Permutation: +// +// v9997775 55333111 u8886664 44222000 (u, v lsb) +// vjjjhhhf ffdddbbb uiiiggge eecccaaa +// vtttrrrp ppnnnlll usssqqqo oommmkkk + +__forceinline__ __device__ void shuffle_3bit_32 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 
16)); } + for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } + for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half z1_ = __float2half_rn(-1024.0f - 4.0f); + const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f); + const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[ 0] = __hadd2( q0.as_half2, z1); + dq[ 1] = __hfma2( q1.as_half2, y8, z8); + dq[ 2] = __hadd2( q2.as_half2, z1); + dq[ 3] = __hfma2( q3.as_half2, y8, z8); + dq[ 4] = __hfma2( q4.as_half2, y64, z64); + dq[ 5] = __hadd2( q5.as_half2, z1); + dq[ 6] = __hfma2( q6.as_half2, y8, z8); + dq[ 7] = __hadd2( q7.as_half2, z1); + dq[ 8] = __hfma2( q8.as_half2, y8, z8); + dq[ 9] = __hfma2( q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = 
__hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); +} + +#else + +__forceinline__ __device__ void shuffle_3bit_32 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_3bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], + int stride +) +{ + half dqh[32]; + for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4); + dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4); + for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4); + dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4); + for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4); + + for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh new file mode 100644 index 00000000..ad95edb4 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh @@ -0,0 +1,227 @@ +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_4BIT == 1 + +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 4; i++) + { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half z1_ = __float2half_rn(-1024.0f - 8.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z16 = __halves2half2(z16_, z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1z16)[2], + half2 (&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1z16)[2], + half2(&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half 
y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} + + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, + bool scaled +) +{ + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) + { + dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } + else + { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) + } +} + +#else + +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + half dqh[8]; + for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8); + + for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1)[2], + half2 (&y1)[2] +) +{ + half z = __int2half_rn(-((int)zero)); + z = __hmul(z, scale); + z1[0] = __half2half2(z); + y1[0] = __half2half2(scale); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1)[2], + half2(&y1)[2] +) +{ + half z = __int2half_rn(-((int)zero)); + z1[0] = __half2half2(z); +} + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1)[2], + half2 (&y1)[2], + int stride, + bool scaled +) +{ + half2 dqh2[8]; + + uint32_t qa = q_0; + for (int i = 0; i < 4; i++) + { + half d0 = __int2half_rn(qa & 0x0f); qa >>= 4; + half d1 = __int2half_rn(qa & 0x0f); qa >>= 4; + dqh2[i] = __halves2half2(d0, d1); + } + + if (scaled) + { + dq[0] = __hfma2(dqh2[0], y1[0], z1[0]); + dq[1] = __hfma2(dqh2[1], y1[0], z1[0]); + dq[2] = __hfma2(dqh2[2], y1[0], z1[0]); + dq[3] = __hfma2(dqh2[3], y1[0], z1[0]); + } + else + { + dq[0] = __hadd2(dqh2[0], z1[0]); + dq[1] = __hadd2(dqh2[1], z1[0]); + dq[2] = __hadd2(dqh2[2], z1[0]); + dq[3] = __hadd2(dqh2[3], z1[0]); + } +} + +#endif + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh new file mode 100644 index 00000000..78d81f92 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh @@ -0,0 +1,207 @@ +#ifndef _qdq_5_cuh +#define _qdq_5_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_5BIT == 1 + +// Permutation: +// +// v5555533 33311111 u4444422 22200000 (u, v lsb) +// vbbbbb99 99977777 uaaaaa88 88866666 +// vhhhhhff fffddddd ugggggee eeeccccc +// vnnnnnll llljjjjj ummmmmkk kkkiiiii 
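The dequant_* routines above never convert integers to half directly: each packed field is OR-ed into c0 = 0x64006400, because for IEEE-754 binary16 the bit pattern 0x6400 | m encodes exactly 1024 + m (for 0 <= m <= 1023), and the precomputed constants such as z1 = -(1024 + 8), y16 = 1/16 and z16 = -(1024/16 + 8) then shift and rescale that back to the centred value q - 8 with a single __hadd2 or __hfma2. A minimal host-side sketch of that identity, assuming the standard binary16 layout (half_bits_to_float is an illustrative helper, not part of the patch):

```cpp
// Illustrative check of the 0x6400 trick used by dequant_4bit_8 (not part of the patch).
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode a normal binary16 bit pattern; enough for the 1024..2047 range used here.
static float half_bits_to_float(uint16_t h)
{
    int exponent = (h >> 10) & 0x1f;                          // biased exponent
    int mantissa =  h        & 0x3ff;                         // 10 fraction bits
    return std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
}

int main()
{
    for (unsigned q = 0; q < 16; q++)
    {
        uint16_t lo = (uint16_t)(0x6400 | q);                 // low lane:  (qa & 0x000f000f) | c0
        uint16_t hi = (uint16_t)(0x6400 | (q << 4));          // high lane: (qa & 0x00f000f0) | c0
        float dq_lo = half_bits_to_float(lo) + (-1024.0f - 8.0f);                           // + z1
        float dq_hi = half_bits_to_float(hi) * (1.0f / 16.0f) + (-1024.0f / 16.0f - 8.0f);  // * y16 + z16
        printf("q=%2u  dq_lo=%5.1f  dq_hi=%5.1f\n", q, dq_lo, dq_hi);                       // both print q - 8
    }
    return 0;
}
```

The 3-bit and 5-bit variants play the same game; only the field masks and the y8/y64 and y32 rescaling constants change.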
+// vtttttrr rrrppppp usssssqq qqqooooo + +__forceinline__ __device__ void shuffle_5bit_32 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + uint32_t qd = q[3 * stride]; + uint32_t qe = q[4 * stride]; + + // qa: 66555554 44443333 32222211 11100000 + // qb: ccccbbbb baaaaa99 99988888 77777666 + // qc: jiiiiihh hhhggggg fffffeee eedddddc + // qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj + // qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp + + uint32_t qf = qe >> 22; + qe <<= 8; + qe |= qd >> 24; + qd <<= 6; + qd |= qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: 555554 44443333 32222211 11100000 + // qb: bbbbba aaaa9999 98888877 77766666 + // qc: hhhhhg ggggffff feeeeedd dddccccc + // qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii + // qe: ttttts ssssrrrr rqqqqqpp pppooooo + // qf: vv vvvuuuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + uint32_t zd = 0; + uint32_t ze = 0; + + for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); } + + // za: 5555533 33311111 4444422 22200000 + // zb: bbbbb99 99977777 aaaaa88 88866666 + // zc: hhhhhff fffddddd gggggee eeeccccc + // zd: nnnnnll llljjjjj mmmmmkk kkkiiiii + // ze: tttttrr rrrppppp sssssqq qqqooooo + // qf: vv vvvuuuuu + + za |= ((qf & 0x001) >> 0) << 15; + zb |= ((qf & 0x002) >> 1) << 15; + zc |= ((qf & 0x004) >> 2) << 15; + zd |= ((qf & 0x008) >> 3) << 15; + ze |= ((qf & 0x010) >> 4) << 15; + za |= ((qf & 0x020) >> 5) << 31; + zb |= ((qf & 0x040) >> 6) << 31; + zc |= ((qf & 0x080) >> 7) << 31; + zd |= ((qf & 0x100) >> 8) << 31; + ze |= ((qf & 0x200) >> 9) << 31; + + // za: v5555533 33311111 u4444422 22200000 (u, v lsb) + // zb: vbbbbb99 99977777 uaaaaa88 88866666 + // zc: vhhhhhff fffddddd ugggggee eeeccccc + // zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii + // ze: vtttttrr rrrppppp usssssqq qqqooooo + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; + q[3 * stride] = zd; + q[4 * stride] = ze; +} + +__forceinline__ __device__ void dequant_5bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + const uint32_t q_3, + const uint32_t q_4, + half2 (&dq)[16], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y32_ = __float2half_rn(1.0f / 32.0f); + const half2 y32 = __halves2half2(y32_, y32_); + const half z1_ = __float2half_rn(-1024.0f - 16.0f); + const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z32 = __halves2half2(z32_, z32_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + uint32_t qd = q_3; + uint32_t qe = q_4; + + half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024 + qa >>= 10; + half2_uint32 q2 ((qa & 
0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024 + qa >>= 5; + qa &= 0x00010001; + half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024 + half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024 + qb >>= 10; + half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024 + qb >>= 4; + qb &= 0x00020002; + half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024 + half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024 + qc >>= 10; + half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024 + qc >>= 3; + qc &= 0x00040004; + half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024 + half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024 + qd >>= 10; + half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024 + qd >>= 2; + qd &= 0x00080008; + half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024 + qe >>= 10; + half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024 + qe >>= 1; + qe &= 0x00100010; + half2_uint32 q15((qa | qb | qc | qd | qe) | c0); + + dq[ 0] = __hadd2( q0.as_half2, z1); + dq[ 1] = __hfma2( q1.as_half2, y32, z32); + dq[ 2] = __hadd2( q2.as_half2, z1); + dq[ 3] = __hadd2( q3.as_half2, z1); + dq[ 4] = __hfma2( q4.as_half2, y32, z32); + dq[ 5] = __hadd2( q5.as_half2, z1); + dq[ 6] = __hadd2( q6.as_half2, z1); + dq[ 7] = __hfma2( q7.as_half2, y32, z32); + dq[ 8] = __hadd2( q8.as_half2, z1); + dq[ 9] = __hadd2( q9.as_half2, z1); + dq[10] = __hfma2(q10.as_half2, y32, z32); + dq[11] = __hadd2(q11.as_half2, z1); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y32, z32); + dq[14] = __hadd2(q14.as_half2, z1); + dq[15] = __hadd2(q15.as_half2, z1); +} + +#else + +__forceinline__ __device__ void shuffle_5bit_32 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_5bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + const uint32_t q_3, + const uint32_t q_4, + half2 (&dq)[16], + int stride +) +{ + half dqh[32]; + for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16); + dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16); + for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16); + dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16); + for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16); + dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16); + for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16); + dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16); + for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16); + + for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh new file mode 100644 index 00000000..562fe695 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh @@ -0,0 +1,42 @@ +#ifndef _qdq_6_cuh +#define _qdq_6_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_6BIT == 1 + + // Not implemented + +#else + +__forceinline__ __device__ void shuffle_6bit_16 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_6bit_16 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 
(&dq)[8], + int stride +) +{ + half dqh[16]; + for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32); + dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32); + for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32); + dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32); + for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32); + + for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh new file mode 100644 index 00000000..6e6bedbd --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh @@ -0,0 +1,38 @@ +#ifndef _qdq_8_cuh +#define _qdq_8_cuh + +#include "qdq_util.cuh" +#include "../../config.h" + +#if QMODE_8BIT == 1 + + // Not implemented + +#else + +__forceinline__ __device__ void shuffle_8bit_4 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_8bit_8 +( + const uint32_t q_0, + const uint32_t q_1, + half2 (&dq)[4], + int stride +) +{ + half dqh[8]; + for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128); + for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128); + + for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh new file mode 100644 index 00000000..cac9df9c --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh @@ -0,0 +1,53 @@ +#ifndef _qdq_util_cuh +#define _qdq_util_cuh + +union half2_uint32 +{ + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} + __device__ half2_uint32() : as_uint32(0) {} +}; + +union half_uint16 +{ + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} + __device__ half_uint16() : as_uint16(0) {} +}; + +// Max_scale premultiplied by 1/256 + +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) +{ + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) +{ + return __hmul(__int2half_rn(q - qzero), scale); +} + +__forceinline__ __device__ half dq_ns(const int q, const int qzero) +{ + //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); +} + +__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) +{ + return (int)((q >> shift) & mask); +} + +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) +{ + return (int)(__funnelshift_rc(q0, q1, shift) & mask); +} + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh new file mode 100644 index 00000000..e167bc23 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh @@ -0,0 +1,54 @@ +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include +#include + +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +#define DBGS(__x) printf("%s\n", __x) 
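The two-argument exb() overload in qdq_util.cuh above uses __funnelshift_rc, which concatenates hi:lo into a 64-bit value, shifts it right and returns the low 32 bits; the fallback (QMODE_*BIT == 0) dequant paths rely on it whenever a field straddles two packed words, e.g. dq_ns(exb(q_1, q_0, 30, 0x1f), 16) in the 5-bit path. A rough host-side stand-in, for illustration only (exb_ref is not part of the patch; the shift amounts at these call sites are all below 32):

```cpp
// Host-side stand-in (illustration only) for the two-word exb() overload above.
#include <cstdint>
#include <cstdio>

static int exb_ref(uint32_t q1, uint32_t q0, int shift, int mask)
{
    // q1 supplies the upper 32 bits and q0 the lower 32, mirroring __funnelshift_rc(q0, q1, shift).
    uint64_t concat = ((uint64_t)q1 << 32) | q0;
    return (int)((concat >> shift) & (uint64_t)mask);
}

int main()
{
    // A 5-bit field whose low 2 bits sit at the top of q0 and whose high 3 bits
    // sit at the bottom of q1, as read by exb(q_1, q_0, 30, 0x1f) above.
    uint32_t q0 = 0x3u << 30;   // field bits 1..0 = 0b11
    uint32_t q1 = 0x5u;         // field bits 4..2 = 0b101
    printf("field = %d\n", exb_ref(q1, q0, 30, 0x1f));   // prints 23 (0b10111)
    return 0;
}
```

dq_ns() then simply subtracts the mid-point of the range (4, 16, 32 or 128 for the 3-, 5-, 6- and 8-bit paths) to centre the result, which is why the fallback loops pass those constants explicitly.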
+#define DBGI(__x) printf("%s: %i\n", #__x, __x) +#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y) +#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGX(__x) printf("%s: %x\n", #__x, __x) +#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y) +#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGF(__x) printf("%s: %f\n", #__x, __x) +#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y) +#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x)) +#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y)) +#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z)) + +#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y)) +#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z)) + +__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale) +{ + half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f)); + qs_h = __hmul(qs_h, qs_h); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ float clamp(float x, float a, float b) +{ + return fmaxf(a, fminf(b, x)); +} + +#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); } +inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true) +{ + if (code != cudaSuccess) + { + fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) exit(code); + } +} + +void print_global_mem(const half* ptr, int rows, int columns, int stride); + +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/ext.cpp b/server/exllamav2_kernels/exllamav2_kernels/ext.cpp new file mode 100644 index 00000000..ff4e1851 --- /dev/null +++ b/server/exllamav2_kernels/exllamav2_kernels/ext.cpp @@ -0,0 +1,139 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "config.h" + +#include "cuda/q_matrix.cuh" +#include "cuda/q_gemm.cuh" + +#include "cpp/util.h" + +// Some decluttering macros + +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") + + +// Quant matrix + +uintptr_t make_q_matrix +( + torch::Tensor q_weight, + torch::Tensor q_perm, + torch::Tensor q_invperm, + torch::Tensor q_scale, + torch::Tensor q_scale_max, + torch::Tensor q_groups, + torch::Tensor q_group_map, + torch::Tensor gptq_qzeros, + torch::Tensor gptq_scales, + torch::Tensor gptq_g_idx, + torch::Tensor temp_dq +) +{ + TORCH_CHECK_DTYPE(q_weight, kInt); + TORCH_CHECK_DTYPE_OPT(q_perm, kShort); + TORCH_CHECK_DTYPE_OPT(q_invperm, 
kShort); + TORCH_CHECK_DTYPE_OPT(q_scale, kInt); + TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf); + TORCH_CHECK_DTYPE_OPT(q_groups, kShort); + TORCH_CHECK_DTYPE_OPT(q_group_map, kShort); + TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt); + TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf); + TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt); + + TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1); + + int device = q_weight.device().index(); + int width = q_weight.size(1); + int groups; + int height; + + if (!q_scale.device().is_meta()) + { + TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8); + TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1); + groups = q_scale.size(0); + height = q_invperm.size(0); + } + else + { + TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8); + TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1); + groups = gptq_qzeros.size(0); + height = q_weight.size(0) * 8; + } + + TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer") + + QMatrix* m = new QMatrix + ( + device, + height, + width, + groups, + (uint32_t*) q_weight.data_ptr(), + q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(), + q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(), + q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(), + q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(), + q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(), + q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(), + gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), + gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), + gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(), + (half*) temp_dq.data_ptr() + ); + + if (m->failed) throw std::runtime_error("CUDA out of memory"); + + return reinterpret_cast<uintptr_t> (m); +} + +void gemm_half_q_half +( + torch::Tensor a, + uintptr_t b, + torch::Tensor c, + bool force_cuda +) +{ + QMatrix* qm = reinterpret_cast<QMatrix*> (b); + + TORCH_CHECK_DTYPE(a, kHalf); + TORCH_CHECK_DTYPE(c, kHalf); + TORCH_CHECK_SHAPES(a, 0, c, 0, 1); + TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes") + TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + + gemm_half_q_half_cuda + ( + at::cuda::getCurrentCUDABlasHandle(), + (const half*) a.data_ptr(), + qm, + (half*) c.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + true, + NULL, + force_cuda + ); +} + +// Bindings + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("make_q_matrix", &make_q_matrix, "make_q_matrix"); + m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half"); +} diff --git a/server/exllamav2_kernels/setup.py b/server/exllamav2_kernels/setup.py new file mode 100644 index 00000000..4a16b546 --- /dev/null +++ b/server/exllamav2_kernels/setup.py @@ -0,0 +1,28 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +extra_cuda_cflags = ["-lineinfo", "-O3"] + +if torch.version.hip: + extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"] + +extra_compile_args = { + "nvcc": extra_cuda_cflags, +} + +setup( + name="exllamav2_kernels", + ext_modules=[ + CUDAExtension( + name="exllamav2_kernels", + sources=[ + "exllamav2_kernels/ext.cpp", + "exllamav2_kernels/cuda/q_matrix.cu", + "exllamav2_kernels/cuda/q_gemm.cu", + ], + extra_compile_args=extra_compile_args, + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff
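In the GPTQ branch of make_q_matrix above, the shape checks encode the usual GPTQ packing: q_weight holds eight 4-bit values per int32 along dim 0 (hence height = q_weight.size(0) * 8), gptq_qzeros has one row per group with eight zero-points per int32 along dim 1 (hence the q_weight/gptq_qzeros 8:1 check), and gptq_scales is one fp16 row per group. A small arithmetic sketch with hypothetical layer dimensions (4096 x 11008, group size 128; the numbers are illustrative, not from the patch):

```cpp
// Illustrative shape arithmetic for the GPTQ branch of make_q_matrix (hypothetical dims).
#include <cstdio>

int main()
{
    int k = 4096, n = 11008, g = 128;     // in_features, out_features, group size (assumed)

    int qweight_rows = k / 8;             // q_weight:    512 x 11008  (int32, 8 nibbles per word)
    int qzeros_rows  = k / g;             // gptq_qzeros:  32 x 1376   (int32, 8 zero-points per word)
    int qzeros_cols  = n / 8;
    int scales_rows  = k / g;             // gptq_scales:  32 x 11008  (fp16)

    // What make_q_matrix recovers in its else-branch:
    printf("height = %d, groups = %d\n", qweight_rows * 8, qzeros_rows);   // 4096, 32
    printf("qzeros = %d x %d, scales = %d x %d\n",
           qzeros_rows, qzeros_cols, scales_rows, n);
    return 0;
}
```

The temp_dq check then just asks for at least width * height half elements of scratch space, enough to hold one fully dequantized weight matrix.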
--git a/server/poetry.lock b/server/poetry.lock new file mode 100644 index 00000000..14d0446c --- /dev/null +++ b/server/poetry.lock @@ -0,0 +1,3492 @@ +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. + +[[package]] +name = "accelerate" +version = "0.29.3" +description = "Accelerate" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "accelerate-0.29.3-py3-none-any.whl", hash = "sha256:99d633d4b6126817c5e554487406748be95c8d1d1e659dd2fd60657e35f532dd"}, + {file = "accelerate-0.29.3.tar.gz", hash = "sha256:1a5a845b06b24b41736b219b2b20fd021ca5dff4070a252445fd6de736e347ac"}, +] + +[package.dependencies] +huggingface-hub = "*" +numpy = ">=1.17" +packaging = ">=20.0" +psutil = "*" +pyyaml = "*" +safetensors = ">=0.3.1" +torch = ">=1.10.0" + +[package.extras] +dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.2.1,<0.3.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.2.1,<0.3.0)"] +rich = ["rich"] +sagemaker = ["sagemaker"] +test-dev = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist"] +test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"] +testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] + +[[package]] +name = "aiohttp" +version = "3.9.5" +description = "Async http client/server framework (asyncio)" +optional = true +python-versions = ">=3.8" +files = [ + {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fcde4c397f673fdec23e6b05ebf8d4751314fa7c24f93334bf1f1364c1c69ac7"}, + {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d6b3f1fabe465e819aed2c421a6743d8debbde79b6a8600739300630a01bf2c"}, + {file = "aiohttp-3.9.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae79c1bc12c34082d92bf9422764f799aee4746fd7a392db46b7fd357d4a17a"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d3ebb9e1316ec74277d19c5f482f98cc65a73ccd5430540d6d11682cd857430"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84dabd95154f43a2ea80deffec9cb44d2e301e38a0c9d331cc4aa0166fe28ae3"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a02fbeca6f63cb1f0475c799679057fc9268b77075ab7cf3f1c600e81dd46b"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c26959ca7b75ff768e2776d8055bf9582a6267e24556bb7f7bd29e677932be72"}, + {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:714d4e5231fed4ba2762ed489b4aec07b2b9953cf4ee31e9871caac895a839c0"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7a6a8354f1b62e15d48e04350f13e726fa08b62c3d7b8401c0a1314f02e3558"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_i686.whl", hash = 
"sha256:c413016880e03e69d166efb5a1a95d40f83d5a3a648d16486592c49ffb76d0db"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ff84aeb864e0fac81f676be9f4685f0527b660f1efdc40dcede3c251ef1e867f"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ad7f2919d7dac062f24d6f5fe95d401597fbb015a25771f85e692d043c9d7832"}, + {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:702e2c7c187c1a498a4e2b03155d52658fdd6fda882d3d7fbb891a5cf108bb10"}, + {file = "aiohttp-3.9.5-cp310-cp310-win32.whl", hash = "sha256:67c3119f5ddc7261d47163ed86d760ddf0e625cd6246b4ed852e82159617b5fb"}, + {file = "aiohttp-3.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:471f0ef53ccedec9995287f02caf0c068732f026455f07db3f01a46e49d76bbb"}, + {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ae53e33ee7476dd3d1132f932eeb39bf6125083820049d06edcdca4381f342"}, + {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c088c4d70d21f8ca5c0b8b5403fe84a7bc8e024161febdd4ef04575ef35d474d"}, + {file = "aiohttp-3.9.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:639d0042b7670222f33b0028de6b4e2fad6451462ce7df2af8aee37dcac55424"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f26383adb94da5e7fb388d441bf09c61e5e35f455a3217bfd790c6b6bc64b2ee"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66331d00fb28dc90aa606d9a54304af76b335ae204d1836f65797d6fe27f1ca2"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff550491f5492ab5ed3533e76b8567f4b37bd2995e780a1f46bca2024223233"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f22eb3a6c1080d862befa0a89c380b4dafce29dc6cd56083f630073d102eb595"}, + {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a81b1143d42b66ffc40a441379387076243ef7b51019204fd3ec36b9f69e77d6"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f64fd07515dad67f24b6ea4a66ae2876c01031de91c93075b8093f07c0a2d93d"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:93e22add827447d2e26d67c9ac0161756007f152fdc5210277d00a85f6c92323"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:55b39c8684a46e56ef8c8d24faf02de4a2b2ac60d26cee93bc595651ff545de9"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4715a9b778f4293b9f8ae7a0a7cef9829f02ff8d6277a39d7f40565c737d3771"}, + {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afc52b8d969eff14e069a710057d15ab9ac17cd4b6753042c407dcea0e40bf75"}, + {file = "aiohttp-3.9.5-cp311-cp311-win32.whl", hash = "sha256:b3df71da99c98534be076196791adca8819761f0bf6e08e07fd7da25127150d6"}, + {file = "aiohttp-3.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:88e311d98cc0bf45b62fc46c66753a83445f5ab20038bcc1b8a1cc05666f428a"}, + {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c7a4b7a6cf5b6eb11e109a9755fd4fda7d57395f8c575e166d363b9fc3ec4678"}, + {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0a158704edf0abcac8ac371fbb54044f3270bdbc93e254a82b6c82be1ef08f3c"}, + {file = "aiohttp-3.9.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d153f652a687a8e95ad367a86a61e8d53d528b0530ef382ec5aaf533140ed00f"}, + {file = 
"aiohttp-3.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82a6a97d9771cb48ae16979c3a3a9a18b600a8505b1115cfe354dfb2054468b4"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60cdbd56f4cad9f69c35eaac0fbbdf1f77b0ff9456cebd4902f3dd1cf096464c"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8676e8fd73141ded15ea586de0b7cda1542960a7b9ad89b2b06428e97125d4fa"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da00da442a0e31f1c69d26d224e1efd3a1ca5bcbf210978a2ca7426dfcae9f58"}, + {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18f634d540dd099c262e9f887c8bbacc959847cfe5da7a0e2e1cf3f14dbf2daf"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:320e8618eda64e19d11bdb3bd04ccc0a816c17eaecb7e4945d01deee2a22f95f"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:2faa61a904b83142747fc6a6d7ad8fccff898c849123030f8e75d5d967fd4a81"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:8c64a6dc3fe5db7b1b4d2b5cb84c4f677768bdc340611eca673afb7cf416ef5a"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:393c7aba2b55559ef7ab791c94b44f7482a07bf7640d17b341b79081f5e5cd1a"}, + {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c671dc117c2c21a1ca10c116cfcd6e3e44da7fcde37bf83b2be485ab377b25da"}, + {file = "aiohttp-3.9.5-cp312-cp312-win32.whl", hash = "sha256:5a7ee16aab26e76add4afc45e8f8206c95d1d75540f1039b84a03c3b3800dd59"}, + {file = "aiohttp-3.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:5ca51eadbd67045396bc92a4345d1790b7301c14d1848feaac1d6a6c9289e888"}, + {file = "aiohttp-3.9.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:694d828b5c41255e54bc2dddb51a9f5150b4eefa9886e38b52605a05d96566e8"}, + {file = "aiohttp-3.9.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0605cc2c0088fcaae79f01c913a38611ad09ba68ff482402d3410bf59039bfb8"}, + {file = "aiohttp-3.9.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4558e5012ee03d2638c681e156461d37b7a113fe13970d438d95d10173d25f78"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dbc053ac75ccc63dc3a3cc547b98c7258ec35a215a92bd9f983e0aac95d3d5b"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4109adee842b90671f1b689901b948f347325045c15f46b39797ae1bf17019de"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6ea1a5b409a85477fd8e5ee6ad8f0e40bf2844c270955e09360418cfd09abac"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3c2890ca8c59ee683fd09adf32321a40fe1cf164e3387799efb2acebf090c11"}, + {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3916c8692dbd9d55c523374a3b8213e628424d19116ac4308e434dbf6d95bbdd"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8d1964eb7617907c792ca00b341b5ec3e01ae8c280825deadbbd678447b127e1"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d5ab8e1f6bee051a4bf6195e38a5c13e5e161cb7bad83d8854524798bd9fcd6e"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = 
"sha256:52c27110f3862a1afbcb2af4281fc9fdc40327fa286c4625dfee247c3ba90156"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:7f64cbd44443e80094309875d4f9c71d0401e966d191c3d469cde4642bc2e031"}, + {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8b4f72fbb66279624bfe83fd5eb6aea0022dad8eec62b71e7bf63ee1caadeafe"}, + {file = "aiohttp-3.9.5-cp38-cp38-win32.whl", hash = "sha256:6380c039ec52866c06d69b5c7aad5478b24ed11696f0e72f6b807cfb261453da"}, + {file = "aiohttp-3.9.5-cp38-cp38-win_amd64.whl", hash = "sha256:da22dab31d7180f8c3ac7c7635f3bcd53808f374f6aa333fe0b0b9e14b01f91a"}, + {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1732102949ff6087589408d76cd6dea656b93c896b011ecafff418c9661dc4ed"}, + {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c6021d296318cb6f9414b48e6a439a7f5d1f665464da507e8ff640848ee2a58a"}, + {file = "aiohttp-3.9.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:239f975589a944eeb1bad26b8b140a59a3a320067fb3cd10b75c3092405a1372"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b7b30258348082826d274504fbc7c849959f1989d86c29bc355107accec6cfb"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2adf5c87ff6d8b277814a28a535b59e20bfea40a101db6b3bdca7e9926bc24"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a3d838441bebcf5cf442700e3963f58b5c33f015341f9ea86dcd7d503c07e2"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3a1ae66e3d0c17cf65c08968a5ee3180c5a95920ec2731f53343fac9bad106"}, + {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c69e77370cce2d6df5d12b4e12bdcca60c47ba13d1cbbc8645dd005a20b738b"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0cbf56238f4bbf49dab8c2dc2e6b1b68502b1e88d335bea59b3f5b9f4c001475"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d1469f228cd9ffddd396d9948b8c9cd8022b6d1bf1e40c6f25b0fb90b4f893ed"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:45731330e754f5811c314901cebdf19dd776a44b31927fa4b4dbecab9e457b0c"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:3fcb4046d2904378e3aeea1df51f697b0467f2aac55d232c87ba162709478c46"}, + {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8cf142aa6c1a751fcb364158fd710b8a9be874b81889c2bd13aa8893197455e2"}, + {file = "aiohttp-3.9.5-cp39-cp39-win32.whl", hash = "sha256:7b179eea70833c8dee51ec42f3b4097bd6370892fa93f510f76762105568cf09"}, + {file = "aiohttp-3.9.5-cp39-cp39-win_amd64.whl", hash = "sha256:38d80498e2e169bc61418ff36170e0aad0cd268da8b38a17c4cf29d254a8b3f1"}, + {file = "aiohttp-3.9.5.tar.gz", hash = "sha256:edea7d15772ceeb29db4aff55e482d4bcfb6ae160ce144f2682de02f6d693551"}, +] + +[package.dependencies] +aiosignal = ">=1.1.2" +async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} +attrs = ">=17.3.0" +frozenlist = ">=1.1.1" +multidict = ">=4.5,<7.0" +yarl = ">=1.0,<2.0" + +[package.extras] +speedups = ["Brotli", "aiodns", "brotlicffi"] + +[[package]] +name = "aiosignal" +version = "1.3.1" +description = "aiosignal: a list of registered asynchronous callbacks" +optional = true +python-versions = ">=3.7" +files = [ + {file = "aiosignal-1.3.1-py3-none-any.whl", 
hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, + {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, +] + +[package.dependencies] +frozenlist = ">=1.1.0" + +[[package]] +name = "annotated-types" +version = "0.7.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = true +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, +] + +[[package]] +name = "async-timeout" +version = "4.0.3" +description = "Timeout context manager for asyncio programs" +optional = true +python-versions = ">=3.7" +files = [ + {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, +] + +[[package]] +name = "attrs" +version = "23.2.0" +description = "Classes Without Boilerplate" +optional = true +python-versions = ">=3.7" +files = [ + {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, + {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, +] + +[package.extras] +cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] +dev = ["attrs[tests]", "pre-commit"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] +tests = ["attrs[tests-no-zope]", "zope-interface"] +tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] +tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] + +[[package]] +name = "backoff" +version = "2.2.1" +description = "Function decoration for backoff and retry" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"}, + {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, +] + +[[package]] +name = "bitsandbytes" +version = "0.43.1" +description = "k-bit optimizers and matrix multiplication routines." +optional = true +python-versions = "*" +files = [ + {file = "bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:a81c826d576d6d691c7b4a7491c8fdc0f37f769795d6ca2e54afa605d2c260a3"}, + {file = "bitsandbytes-0.43.1-py3-none-win_amd64.whl", hash = "sha256:52c1c7189a6ca006555a9663e544e75f40520a97a26e075411f9f9aca0771fcd"}, +] + +[package.dependencies] +numpy = "*" +torch = "*" + +[package.extras] +benchmark = ["matplotlib", "pandas"] +test = ["scipy"] + +[[package]] +name = "certifi" +version = "2024.2.2" +description = "Python package for providing Mozilla's CA Bundle." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"}, + {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, +] + +[[package]] +name = "charset-normalizer" +version = "3.3.2" +description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash 
= "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"}, + {file = 
"charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac"}, + {file = 
"charset_normalizer-3.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-win32.whl", hash = "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"}, + {file = 
"charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"}, + {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, +] + +[[package]] +name = "click" +version = "8.1.7" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "cloudpickle" +version = "3.0.0" +description = "Pickler class to extend the standard pickle.Pickler functionality" +optional = true +python-versions = ">=3.8" +files = [ + {file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"}, + {file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"}, +] + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." 
+optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "datasets" +version = "2.14.4" +description = "HuggingFace community-driven open-source library of datasets" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "datasets-2.14.4-py3-none-any.whl", hash = "sha256:29336bd316a7d827ccd4da2236596279b20ca2ac78f64c04c9483da7cbc2459b"}, + {file = "datasets-2.14.4.tar.gz", hash = "sha256:ef29c2b5841de488cd343cfc26ab979bff77efa4d2285af51f1ad7db5c46a83b"}, +] + +[package.dependencies] +aiohttp = "*" +dill = ">=0.3.0,<0.3.8" +fsspec = {version = ">=2021.11.1", extras = ["http"]} +huggingface-hub = ">=0.14.0,<1.0.0" +multiprocess = "*" +numpy = ">=1.17" +packaging = "*" +pandas = "*" +pyarrow = ">=8.0.0" +pyyaml = ">=5.1" +requests = ">=2.19.0" +tqdm = ">=4.62.1" +xxhash = "*" + +[package.extras] +apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"] +audio = ["librosa", "soundfile (>=0.12.1)"] +benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] +dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"] +jax = ["jax (>=0.2.8,!=0.3.2,<=0.3.25)", "jaxlib (>=0.1.65,<=0.3.25)"] +metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] +quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"] +s3 = ["s3fs"] +tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] +tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] +tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +torch = ["torch"] +vision = ["Pillow (>=6.2.1)"] + +[[package]] +name = "deprecated" +version = "1.2.14" +description = "Python @deprecated decorator to deprecate old python classes, functions or methods." 
+optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"}, + {file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"}, +] + +[package.dependencies] +wrapt = ">=1.10,<2" + +[package.extras] +dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] + +[[package]] +name = "dill" +version = "0.3.7" +description = "serialize all of Python" +optional = true +python-versions = ">=3.7" +files = [ + {file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"}, + {file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] + +[[package]] +name = "diskcache" +version = "5.6.3" +description = "Disk Cache -- Disk and file backed persistent cache." +optional = true +python-versions = ">=3" +files = [ + {file = "diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19"}, + {file = "diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc"}, +] + +[[package]] +name = "einops" +version = "0.6.1" +description = "A new flavour of deep learning operations" +optional = false +python-versions = ">=3.7" +files = [ + {file = "einops-0.6.1-py3-none-any.whl", hash = "sha256:99149e46cc808956b174932fe563d920db4d6e5dadb8c6ecdaa7483b7ef7cfc3"}, + {file = "einops-0.6.1.tar.gz", hash = "sha256:f95f8d00f4ded90dbc4b19b6f98b177332614b0357dde66997f3ae5d474dc8c8"}, +] + +[[package]] +name = "exceptiongroup" +version = "1.2.1" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"}, + {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "filelock" +version = "3.14.0" +description = "A platform independent file lock." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, + {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +typing = ["typing-extensions (>=4.8)"] + +[[package]] +name = "frozenlist" +version = "1.4.1" +description = "A list-like structure which implements collections.abc.MutableSequence" +optional = true +python-versions = ">=3.8" +files = [ + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"}, + {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"}, + {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"}, + {file = 
"frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"}, + {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"}, + {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = 
"sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"}, + {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"}, + {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"}, + {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"}, + {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = 
"sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"}, + {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"}, + {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"}, + {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"}, + {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, +] + +[[package]] +name = "fsspec" +version = "2024.5.0" +description = "File-system specification" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fsspec-2024.5.0-py3-none-any.whl", hash = "sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c"}, + {file = "fsspec-2024.5.0.tar.gz", hash = "sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a"}, +] + +[package.dependencies] +aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""} + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +dev = ["pre-commit", "ruff"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp 
(!=4.0.0a0,!=4.0.0a1)"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"] +test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"] +test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] +tqdm = ["tqdm"] + +[[package]] +name = "googleapis-common-protos" +version = "1.63.0" +description = "Common protobufs used in Google APIs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "googleapis-common-protos-1.63.0.tar.gz", hash = "sha256:17ad01b11d5f1d0171c06d3ba5c04c54474e883b66b949722b4938ee2694ef4e"}, + {file = "googleapis_common_protos-1.63.0-py2.py3-none-any.whl", hash = "sha256:ae45f75702f7c08b541f750854a678bd8f534a1a6bace6afe975f1d0a82d6632"}, +] + +[package.dependencies] +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] + +[[package]] +name = "grpc-interceptor" +version = "0.15.4" +description = "Simplifies gRPC interceptors" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "grpc-interceptor-0.15.4.tar.gz", hash = "sha256:1f45c0bcb58b6f332f37c637632247c9b02bc6af0fdceb7ba7ce8d2ebbfb0926"}, + {file = "grpc_interceptor-0.15.4-py3-none-any.whl", hash = "sha256:0035f33228693ed3767ee49d937bac424318db173fef4d2d0170b3215f254d9d"}, +] + +[package.dependencies] +grpcio = ">=1.49.1,<2.0.0" + +[package.extras] +testing = ["protobuf (>=4.21.9)"] + +[[package]] +name = "grpcio" +version = "1.64.0" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.8" +files = [ + {file = "grpcio-1.64.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:3b09c3d9de95461214a11d82cc0e6a46a6f4e1f91834b50782f932895215e5db"}, + {file = "grpcio-1.64.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:7e013428ab472892830287dd082b7d129f4d8afef49227a28223a77337555eaa"}, + {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:02cc9cc3f816d30f7993d0d408043b4a7d6a02346d251694d8ab1f78cc723e7e"}, + {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f5de082d936e0208ce8db9095821361dfa97af8767a6607ae71425ac8ace15c"}, + {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7b7bf346391dffa182fba42506adf3a84f4a718a05e445b37824136047686a1"}, + {file = "grpcio-1.64.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b2cbdfba18408389a1371f8c2af1659119e1831e5ed24c240cae9e27b4abc38d"}, + {file = "grpcio-1.64.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:aca4f15427d2df592e0c8f3d38847e25135e4092d7f70f02452c0e90d6a02d6d"}, + {file = 
"grpcio-1.64.0-cp310-cp310-win32.whl", hash = "sha256:7c1f5b2298244472bcda49b599be04579f26425af0fd80d3f2eb5fd8bc84d106"}, + {file = "grpcio-1.64.0-cp310-cp310-win_amd64.whl", hash = "sha256:73f84f9e5985a532e47880b3924867de16fa1aa513fff9b26106220c253c70c5"}, + {file = "grpcio-1.64.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2a18090371d138a57714ee9bffd6c9c9cb2e02ce42c681aac093ae1e7189ed21"}, + {file = "grpcio-1.64.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:59c68df3a934a586c3473d15956d23a618b8f05b5e7a3a904d40300e9c69cbf0"}, + {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b52e1ec7185512103dd47d41cf34ea78e7a7361ba460187ddd2416b480e0938c"}, + {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d598b5d5e2c9115d7fb7e2cb5508d14286af506a75950762aa1372d60e41851"}, + {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01615bbcae6875eee8091e6b9414072f4e4b00d8b7e141f89635bdae7cf784e5"}, + {file = "grpcio-1.64.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0b2dfe6dcace264807d9123d483d4c43274e3f8c39f90ff51de538245d7a4145"}, + {file = "grpcio-1.64.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7f17572dc9acd5e6dfd3014d10c0b533e9f79cd9517fc10b0225746f4c24b58e"}, + {file = "grpcio-1.64.0-cp311-cp311-win32.whl", hash = "sha256:6ec5ed15b4ffe56e2c6bc76af45e6b591c9be0224b3fb090adfb205c9012367d"}, + {file = "grpcio-1.64.0-cp311-cp311-win_amd64.whl", hash = "sha256:597191370951b477b7a1441e1aaa5cacebeb46a3b0bd240ec3bb2f28298c7553"}, + {file = "grpcio-1.64.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:1ce4cd5a61d4532651079e7aae0fedf9a80e613eed895d5b9743e66b52d15812"}, + {file = "grpcio-1.64.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:650a8150a9b288f40d5b7c1d5400cc11724eae50bd1f501a66e1ea949173649b"}, + {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8de0399b983f8676a7ccfdd45e5b2caec74a7e3cc576c6b1eecf3b3680deda5e"}, + {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46b8b43ba6a2a8f3103f103f97996cad507bcfd72359af6516363c48793d5a7b"}, + {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a54362f03d4dcfae63be455d0a7d4c1403673498b92c6bfe22157d935b57c7a9"}, + {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1f8ea18b928e539046bb5f9c124d717fbf00cc4b2d960ae0b8468562846f5aa1"}, + {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c56c91bd2923ddb6e7ed28ebb66d15633b03e0df22206f22dfcdde08047e0a48"}, + {file = "grpcio-1.64.0-cp312-cp312-win32.whl", hash = "sha256:874c741c8a66f0834f653a69e7e64b4e67fcd4a8d40296919b93bab2ccc780ba"}, + {file = "grpcio-1.64.0-cp312-cp312-win_amd64.whl", hash = "sha256:0da1d921f8e4bcee307aeef6c7095eb26e617c471f8cb1c454fd389c5c296d1e"}, + {file = "grpcio-1.64.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:c46fb6bfca17bfc49f011eb53416e61472fa96caa0979b4329176bdd38cbbf2a"}, + {file = "grpcio-1.64.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3d2004e85cf5213995d09408501f82c8534700d2babeb81dfdba2a3bff0bb396"}, + {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6d5541eb460d73a07418524fb64dcfe0adfbcd32e2dac0f8f90ce5b9dd6c046c"}, + {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f279ad72dd7d64412e10f2443f9f34872a938c67387863c4cd2fb837f53e7d2"}, + {file = 
"grpcio-1.64.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85fda90b81da25993aa47fae66cae747b921f8f6777550895fb62375b776a231"}, + {file = "grpcio-1.64.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a053584079b793a54bece4a7d1d1b5c0645bdbee729215cd433703dc2532f72b"}, + {file = "grpcio-1.64.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:579dd9fb11bc73f0de061cab5f8b2def21480fd99eb3743ed041ad6a1913ee2f"}, + {file = "grpcio-1.64.0-cp38-cp38-win32.whl", hash = "sha256:23b6887bb21d77649d022fa1859e05853fdc2e60682fd86c3db652a555a282e0"}, + {file = "grpcio-1.64.0-cp38-cp38-win_amd64.whl", hash = "sha256:753cb58683ba0c545306f4e17dabf468d29cb6f6b11832e1e432160bb3f8403c"}, + {file = "grpcio-1.64.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:2186d76a7e383e1466e0ea2b0febc343ffeae13928c63c6ec6826533c2d69590"}, + {file = "grpcio-1.64.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0f30596cdcbed3c98024fb4f1d91745146385b3f9fd10c9f2270cbfe2ed7ed91"}, + {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:d9171f025a196f5bcfec7e8e7ffb7c3535f7d60aecd3503f9e250296c7cfc150"}, + {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf4c8daed18ae2be2f1fc7d613a76ee2a2e28fdf2412d5c128be23144d28283d"}, + {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3550493ac1d23198d46dc9c9b24b411cef613798dc31160c7138568ec26bc9b4"}, + {file = "grpcio-1.64.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:3161a8f8bb38077a6470508c1a7301cd54301c53b8a34bb83e3c9764874ecabd"}, + {file = "grpcio-1.64.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2e8fabe2cc57a369638ab1ad8e6043721014fdf9a13baa7c0e35995d3a4a7618"}, + {file = "grpcio-1.64.0-cp39-cp39-win32.whl", hash = "sha256:31890b24d47b62cc27da49a462efe3d02f3c120edb0e6c46dcc0025506acf004"}, + {file = "grpcio-1.64.0-cp39-cp39-win_amd64.whl", hash = "sha256:5a56797dea8c02e7d3a85dfea879f286175cf4d14fbd9ab3ef2477277b927baa"}, + {file = "grpcio-1.64.0.tar.gz", hash = "sha256:257baf07f53a571c215eebe9679c3058a313fd1d1f7c4eede5a8660108c52d9c"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.64.0)"] + +[[package]] +name = "grpcio-reflection" +version = "1.62.2" +description = "Standard Protobuf Reflection Service for gRPC" +optional = false +python-versions = ">=3.6" +files = [ + {file = "grpcio-reflection-1.62.2.tar.gz", hash = "sha256:2dd44806d68d0006636529bda573012b19a42281478c2d051cdaaebb91e2516c"}, + {file = "grpcio_reflection-1.62.2-py3-none-any.whl", hash = "sha256:68e8dff3617a9afaf7c462c688f7ca62b55323f497c662abf9965f2953508885"}, +] + +[package.dependencies] +grpcio = ">=1.62.2" +protobuf = ">=4.21.6" + +[[package]] +name = "grpcio-status" +version = "1.62.2" +description = "Status proto mapping for gRPC" +optional = false +python-versions = ">=3.6" +files = [ + {file = "grpcio-status-1.62.2.tar.gz", hash = "sha256:62e1bfcb02025a1cd73732a2d33672d3e9d0df4d21c12c51e0bbcaf09bab742a"}, + {file = "grpcio_status-1.62.2-py3-none-any.whl", hash = "sha256:206ddf0eb36bc99b033f03b2c8e95d319f0044defae9b41ae21408e7e0cda48f"}, +] + +[package.dependencies] +googleapis-common-protos = ">=1.5.5" +grpcio = ">=1.62.2" +protobuf = ">=4.21.6" + +[[package]] +name = "grpcio-tools" +version = "1.62.2" +description = "Protobuf code generator for gRPC" +optional = false +python-versions = ">=3.7" +files = [ + {file = "grpcio-tools-1.62.2.tar.gz", hash = "sha256:5fd5e1582b678e6b941ee5f5809340be5e0724691df5299aae8226640f94e18f"}, + {file = 
"grpcio_tools-1.62.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:1679b4903aed2dc5bd8cb22a452225b05dc8470a076f14fd703581efc0740cdb"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:9d41e0e47dd075c075bb8f103422968a65dd0d8dc8613288f573ae91eb1053ba"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:987e774f74296842bbffd55ea8826370f70c499e5b5f71a8cf3103838b6ee9c3"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40cd4eeea4b25bcb6903b82930d579027d034ba944393c4751cdefd9c49e6989"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6746bc823958499a3cf8963cc1de00072962fb5e629f26d658882d3f4c35095"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2ed775e844566ce9ce089be9a81a8b928623b8ee5820f5e4d58c1a9d33dfc5ae"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bdc5dd3f57b5368d5d661d5d3703bcaa38bceca59d25955dff66244dbc987271"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-win32.whl", hash = "sha256:3a8d6f07e64c0c7756f4e0c4781d9d5a2b9cc9cbd28f7032a6fb8d4f847d0445"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-win_amd64.whl", hash = "sha256:e33b59fb3efdddeb97ded988a871710033e8638534c826567738d3edce528752"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:472505d030135d73afe4143b0873efe0dcb385bd6d847553b4f3afe07679af00"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:ec674b4440ef4311ac1245a709e87b36aca493ddc6850eebe0b278d1f2b6e7d1"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:184b4174d4bd82089d706e8223e46c42390a6ebac191073b9772abc77308f9fa"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c195d74fe98541178ece7a50dad2197d43991e0f77372b9a88da438be2486f12"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a34d97c62e61bfe9e6cff0410fe144ac8cca2fc979ad0be46b7edf026339d161"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cbb8453ae83a1db2452b7fe0f4b78e4a8dd32be0f2b2b73591ae620d4d784d3d"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4f989e5cebead3ae92c6abf6bf7b19949e1563a776aea896ac5933f143f0c45d"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-win32.whl", hash = "sha256:c48fabe40b9170f4e3d7dd2c252e4f1ff395dc24e49ac15fc724b1b6f11724da"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-win_amd64.whl", hash = "sha256:8c616d0ad872e3780693fce6a3ac8ef00fc0963e6d7815ce9dcfae68ba0fc287"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:10cc3321704ecd17c93cf68c99c35467a8a97ffaaed53207e9b2da6ae0308ee1"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:9be84ff6d47fd61462be7523b49d7ba01adf67ce4e1447eae37721ab32464dd8"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:d82f681c9a9d933a9d8068e8e382977768e7779ddb8870fa0cf918d8250d1532"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:04c607029ae3660fb1624ed273811ffe09d57d84287d37e63b5b802a35897329"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:72b61332f1b439c14cbd3815174a8f1d35067a02047c32decd406b3a09bb9890"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8214820990d01b52845f9fbcb92d2b7384a0c321b303e3ac614c219dc7d1d3af"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:462e0ab8dd7c7b70bfd6e3195eebc177549ede5cf3189814850c76f9a340d7ce"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-win32.whl", hash = "sha256:fa107460c842e4c1a6266150881694fefd4f33baa544ea9489601810c2210ef8"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-win_amd64.whl", hash = "sha256:759c60f24c33a181bbbc1232a6752f9b49fbb1583312a4917e2b389fea0fb0f2"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:45db5da2bcfa88f2b86b57ef35daaae85c60bd6754a051d35d9449c959925b57"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:ab84bae88597133f6ea7a2bdc57b2fda98a266fe8d8d4763652cbefd20e73ad7"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:7a49bccae1c7d154b78e991885c3111c9ad8c8fa98e91233de425718f47c6139"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7e439476b29d6dac363b321781a113794397afceeb97dad85349db5f1cb5e9a"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ea369c4d1567d1acdf69c8ea74144f4ccad9e545df7f9a4fc64c94fa7684ba3"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4f955702dc4b530696375251319d05223b729ed24e8673c2129f7a75d2caefbb"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3708a747aa4b6b505727282ca887041174e146ae030ebcadaf4c1d346858df62"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-win_amd64.whl", hash = "sha256:2ce149ea55eadb486a7fb75a20f63ef3ac065ee6a0240ed25f3549ce7954c653"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:58cbb24b3fa6ae35aa9c210fcea3a51aa5fef0cd25618eb4fd94f746d5a9b703"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:6413581e14a80e0b4532577766cf0586de4dd33766a31b3eb5374a746771c07d"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:47117c8a7e861382470d0e22d336e5a91fdc5f851d1db44fa784b9acea190d87"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f1ba79a253df9e553d20319c615fa2b429684580fa042dba618d7f6649ac7e4"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:04a394cf5e51ba9be412eb9f6c482b6270bd81016e033e8eb7d21b8cc28fe8b5"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3c53b221378b035ae2f1881cbc3aca42a6075a8e90e1a342c2f205eb1d1aa6a1"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c384c838b34d1b67068e51b5bbe49caa6aa3633acd158f1ab16b5da8d226bc53"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-win32.whl", hash = "sha256:19ea69e41c3565932aa28a202d1875ec56786aea46a2eab54a3b28e8a27f9517"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-win_amd64.whl", hash = "sha256:1d768a5c07279a4c461ebf52d0cec1c6ca85c6291c71ec2703fe3c3e7e28e8c4"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:5b07b5874187e170edfbd7aa2ca3a54ebf3b2952487653e8c0b0d83601c33035"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-macosx_10_10_universal2.whl", hash = 
"sha256:d58389fe8be206ddfb4fa703db1e24c956856fcb9a81da62b13577b3a8f7fda7"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:7d8b4e00c3d7237b92260fc18a561cd81f1da82e8be100db1b7d816250defc66"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fe08d2038f2b7c53259b5c49e0ad08c8e0ce2b548d8185993e7ef67e8592cca"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19216e1fb26dbe23d12a810517e1b3fbb8d4f98b1a3fbebeec9d93a79f092de4"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b8574469ecc4ff41d6bb95f44e0297cdb0d95bade388552a9a444db9cd7485cd"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4f6f32d39283ea834a493fccf0ebe9cfddee7577bdcc27736ad4be1732a36399"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-win32.whl", hash = "sha256:76eb459bdf3fb666e01883270beee18f3f11ed44488486b61cd210b4e0e17cc1"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-win_amd64.whl", hash = "sha256:217c2ee6a7ce519a55958b8622e21804f6fdb774db08c322f4c9536c35fdce7c"}, +] + +[package.dependencies] +grpcio = ">=1.62.2" +protobuf = ">=4.21.6,<5.0dev" +setuptools = "*" + +[[package]] +name = "hf-transfer" +version = "0.1.6" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "hf_transfer-0.1.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6fd3d61f9229d27def007e53540412507b74ac2fdb1a29985ae0b6a5137749a2"}, + {file = "hf_transfer-0.1.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b043bb78df1225de043eb041de9d97783fcca14a0bdc1b1d560fc172fc21b648"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7db60dd18eae4fa6ea157235fb82196cde5313995b396d1b591aad3b790a7f8f"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:30d31dbab9b5a558cce407b8728e39d87d7af1ef8745ddb90187e9ae0b9e1e90"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6b368bddd757efc7af3126ba81f9ac8f9435e2cc00902cb3d64f2be28d8f719"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa2086d8aefaaa3e144e167324574882004c0cec49bf2d0638ec4b74732d8da0"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45d8985a0940bfe1535cb4ca781f5c11e47c83798ef3373ee1f5d57bbe527a9c"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f42b89735f1cde22f2a795d1f0915741023235666be7de45879e533c7d6010c"}, + {file = "hf_transfer-0.1.6-cp310-none-win32.whl", hash = "sha256:2d2c4c4613f3ad45b6ce6291e347b2d3ba1b86816635681436567e461cb3c961"}, + {file = "hf_transfer-0.1.6-cp310-none-win_amd64.whl", hash = "sha256:78b0eed8d8dce60168a46e584b9742b816af127d7e410a713e12c31249195342"}, + {file = "hf_transfer-0.1.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f1d8c172153f9a6cdaecf137612c42796076f61f6bea1072c90ac2e17c1ab6fa"}, + {file = "hf_transfer-0.1.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c601996351f90c514a75a0eeb02bf700b1ad1db2d946cbfe4b60b79e29f0b2f"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e585c808405557d3f5488f385706abb696997bbae262ea04520757e30836d9d"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:ec51af1e8cf4268c268bd88932ade3d7ca895a3c661b42493503f02610ae906b"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d106fdf996332f6df3ed3fab6d6332df82e8c1fb4b20fd81a491ca4d2ab5616a"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9c2ee9e9fde5a0319cc0e8ddfea10897482bc06d5709b10a238f1bc2ebcbc0b"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f394ea32bc7802b061e549d3133efc523b4ae4fd19bf4b74b183ca6066eef94e"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4282f09902114cd67fca98a1a1bad569a44521a8395fedf327e966714f68b977"}, + {file = "hf_transfer-0.1.6-cp311-none-win32.whl", hash = "sha256:276dbf307d5ab6f1bcbf57b5918bfcf9c59d6848ccb28242349e1bb5985f983b"}, + {file = "hf_transfer-0.1.6-cp311-none-win_amd64.whl", hash = "sha256:fa475175c51451186bea804471995fa8e7b2a48a61dcca55534911dc25955527"}, + {file = "hf_transfer-0.1.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:23d157a67acfa00007799323a1c441b2bbacc7dee625b016b7946fe0e25e6c89"}, + {file = "hf_transfer-0.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6067342a2864b988f861cd2d31bd78eb1e84d153a3f6df38485b6696d9ad3013"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91cfcb3070e205b58fa8dc8bcb6a62ccc40913fcdb9cd1ff7c364c8e3aa85345"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb76064ac5165d5eeaaf8d0903e8bf55477221ecc2a4a4d69f0baca065ab905b"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dabd3a177d83028f164984cf4dd859f77ec1e20c97a6f307ff8fcada0785ef1"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0bf4254e44f64a26e0a5b73b5d7e8d91bb36870718fb4f8e126ec943ff4c805"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d32c1b106f38f336ceb21531f4db9b57d777b9a33017dafdb6a5316388ebe50"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff05aba3c83921e5c7635ba9f07c693cc893350c447644824043aeac27b285f5"}, + {file = "hf_transfer-0.1.6-cp312-none-win32.whl", hash = "sha256:051ef0c55607652cb5974f59638da035773254b9a07d7ee5b574fe062de4c9d1"}, + {file = "hf_transfer-0.1.6-cp312-none-win_amd64.whl", hash = "sha256:716fb5c574fcbdd8092ce73f9b6c66f42e3544337490f77c60ec07df02bd081b"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0c981134a55965e279cb7be778c1ccaf93f902fc9ebe31da4f30caf824cc4d"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ef1f145f04c5b573915bcb1eb5db4039c74f6b46fce73fc473c4287e613b623"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0a7609b004db3347dbb7796df45403eceb171238210d054d93897d6d84c63a4"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60f0864bf5996773dbd5f8ae4d1649041f773fe9d5769f4c0eeb5553100acef3"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d01e55d630ffe70a4f5d0ed576a04c6a48d7c65ca9a7d18f2fca385f20685a9"}, + {file = 
"hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d855946c5062b665190de15b2bdbd4c8eddfee35350bfb7564592e23d36fbbd3"}, + {file = "hf_transfer-0.1.6-cp37-none-win32.whl", hash = "sha256:fd40b2409cfaf3e8aba20169ee09552f69140e029adeec261b988903ff0c8f6f"}, + {file = "hf_transfer-0.1.6-cp37-none-win_amd64.whl", hash = "sha256:0e0eba49d46d3b5481919aea0794aec625fbc6ecdf13fe7e0e9f3fc5d5ad5971"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e669fecb29fc454449739f9f53ed9253197e7c19e6a6eaa0f08334207af4287"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89f701802892e5eb84f89f402686861f87dc227d6082b05f4e9d9b4e8015a3c3"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f2b0c8b95b01409275d789a9b74d5f2e146346f985d384bf50ec727caf1ccc"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa855a2fa262792a230f9efcdb5da6d431b747d1861d2a69fe7834b19aea077e"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa8ca349afb2f0713475426946261eb2035e4efb50ebd2c1d5ad04f395f4217"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01255f043996bc7d1bae62d8afc5033a90c7e36ce308b988eeb84afe0a69562f"}, + {file = "hf_transfer-0.1.6-cp38-none-win32.whl", hash = "sha256:60b1db183e8a7540cd4f8b2160ff4de55f77cb0c3fc6a10be1e7c30eb1b2bdeb"}, + {file = "hf_transfer-0.1.6-cp38-none-win_amd64.whl", hash = "sha256:fb8be3cba6aaa50ab2e9dffbd25c8eb2046785eeff642cf0cdd0dd9ae6be3539"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d09af35e3e3f09b664e6429e9a0dc200f29c5bdfd88bdd9666de51183b1fe202"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4505bd707cc14d85c800f961fad8ca76f804a8ad22fbb7b1a217d8d0c15e6a5"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c453fd8b0be9740faa23cecd1f28ee9ead7d900cefa64ff836960c503a744c9"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13cb8884e718a78c3b81a8cdec9c7ac196dd42961fce55c3ccff3dd783e5ad7a"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39cd39df171a2b5404de69c4e6cd14eee47f6fe91c1692f939bfb9e59a0110d8"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ff0629ee9f98df57a783599602eb498f9ec3619dc69348b12e4d9d754abf0e9"}, + {file = "hf_transfer-0.1.6-cp39-none-win32.whl", hash = "sha256:164a6ce445eb0cc7c645f5b6e1042c003d33292520c90052b6325f30c98e4c5f"}, + {file = "hf_transfer-0.1.6-cp39-none-win_amd64.whl", hash = "sha256:11b8b4b73bf455f13218c5f827698a30ae10998ca31b8264b51052868c7a9f11"}, + {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:16957ba057376a99ea361074ce1094f61b58e769defa6be2422ae59c0b6a6530"}, + {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7db952112e3b8ee1a5cbf500d2443e9ce4fb893281c5310a3e31469898628005"}, + {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:d39d826a7344f5e39f438d62632acd00467aa54a083b66496f61ef67a9885a56"}, + {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e2653fbfa92e7651db73d99b697c8684e7345c479bd6857da80bed6138abb2"}, + {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:144277e6a86add10b90ec3b583253aec777130312256bfc8d5ade5377e253807"}, + {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb53bcd16365313b2aa0dbdc28206f577d70770f31249cdabc387ac5841edcc"}, + {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:990d73a5a68d8261980f146c51f4c5f9995314011cb225222021ad7c39f3af2d"}, + {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:652406037029ab9b4097b4c5f29321bad5f64c2b46fbff142509d918aec87c29"}, + {file = "hf_transfer-0.1.6.tar.gz", hash = "sha256:deb505a7d417d7055fd7b3549eadb91dfe782941261f3344025c486c16d1d2f9"}, +] + +[[package]] +name = "huggingface-hub" +version = "0.23.2" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"}, + {file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"}, +] + +[package.dependencies] +filelock = "*" +fsspec = ">=2023.5.0" +packaging = ">=20.9" +pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +hf-transfer = ["hf-transfer (>=0.1.4)"] +inference = ["aiohttp", "minijinja (>=1.0)"] +quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"] +tensorflow = ["graphviz", "pydot", "tensorflow"] +tensorflow-testing = ["keras (<3.0)", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["safetensors", "torch"] +typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] + +[[package]] +name = "idna" +version = "3.7" +description = "Internationalized Domain Names in Applications (IDNA)" +optional = false 
+python-versions = ">=3.5" +files = [ + {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, + {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, +] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "intel-openmp" +version = "2021.4.0" +description = "Intel OpenMP* Runtime Library" +optional = true +python-versions = "*" +files = [ + {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, +] + +[[package]] +name = "interegular" +version = "0.3.3" +description = "a regex intersection checker" +optional = true +python-versions = ">=3.7" +files = [ + {file = "interegular-0.3.3-py37-none-any.whl", hash = "sha256:b0c07007d48c89d6d19f7204972d369b2a77222722e126b6aa63aa721dc3b19c"}, + {file = "interegular-0.3.3.tar.gz", hash = "sha256:d9b697b21b34884711399ba0f0376914b81899ce670032486d0d048344a76600"}, +] + +[[package]] +name = "jinja2" +version = "3.1.4" +description = "A very fast and expressive template engine." 
+optional = true +python-versions = ">=3.7" +files = [ + {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, + {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, +] + +[package.dependencies] +MarkupSafe = ">=2.0" + +[package.extras] +i18n = ["Babel (>=2.7)"] + +[[package]] +name = "joblib" +version = "1.4.2" +description = "Lightweight pipelining with Python functions" +optional = true +python-versions = ">=3.8" +files = [ + {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, + {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, +] + +[[package]] +name = "jsonschema" +version = "4.22.0" +description = "An implementation of JSON Schema validation for Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "jsonschema-4.22.0-py3-none-any.whl", hash = "sha256:ff4cfd6b1367a40e7bc6411caec72effadd3db0bbe5017de188f2d6108335802"}, + {file = "jsonschema-4.22.0.tar.gz", hash = "sha256:5b22d434a45935119af990552c862e5d6d564e8f6601206b305a61fdf661a2b7"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +jsonschema-specifications = ">=2023.03.6" +referencing = ">=0.28.4" +rpds-py = ">=0.7.1" + +[package.extras] +format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"] + +[[package]] +name = "jsonschema-specifications" +version = "2023.12.1" +description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +optional = true +python-versions = ">=3.8" +files = [ + {file = "jsonschema_specifications-2023.12.1-py3-none-any.whl", hash = "sha256:87e4fdf3a94858b8a2ba2778d9ba57d8a9cafca7c7489c46ba0d30a8bc6a9c3c"}, + {file = "jsonschema_specifications-2023.12.1.tar.gz", hash = "sha256:48a76787b3e70f5ed53f1160d2b81f586e4ca6d1548c5de7085d1682674764cc"}, +] + +[package.dependencies] +referencing = ">=0.31.0" + +[[package]] +name = "lark" +version = "1.1.9" +description = "a modern parsing library" +optional = true +python-versions = ">=3.6" +files = [ + {file = "lark-1.1.9-py3-none-any.whl", hash = "sha256:a0dd3a87289f8ccbb325901e4222e723e7d745dbfc1803eaf5f3d2ace19cf2db"}, + {file = "lark-1.1.9.tar.gz", hash = "sha256:15fa5236490824c2c4aba0e22d2d6d823575dcaf4cdd1848e34b6ad836240fba"}, +] + +[package.extras] +atomic-cache = ["atomicwrites"] +interegular = ["interegular (>=0.3.1,<0.4.0)"] +nearley = ["js2py"] +regex = ["regex"] + +[[package]] +name = "llvmlite" +version = "0.42.0" +description = "lightweight wrapper around basic LLVM functionality" +optional = true +python-versions = ">=3.9" +files = [ + {file = "llvmlite-0.42.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3366938e1bf63d26c34fbfb4c8e8d2ded57d11e0567d5bb243d89aab1eb56098"}, + {file = "llvmlite-0.42.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c35da49666a21185d21b551fc3caf46a935d54d66969d32d72af109b5e7d2b6f"}, + {file = "llvmlite-0.42.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70f44ccc3c6220bd23e0ba698a63ec2a7d3205da0d848804807f37fc243e3f77"}, + {file = "llvmlite-0.42.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:763f8d8717a9073b9e0246998de89929071d15b47f254c10eef2310b9aac033d"}, + {file = "llvmlite-0.42.0-cp310-cp310-win_amd64.whl", hash = "sha256:8d90edf400b4ceb3a0e776b6c6e4656d05c7187c439587e06f86afceb66d2be5"}, + {file = "llvmlite-0.42.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ae511caed28beaf1252dbaf5f40e663f533b79ceb408c874c01754cafabb9cbf"}, + {file = "llvmlite-0.42.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81e674c2fe85576e6c4474e8c7e7aba7901ac0196e864fe7985492b737dbab65"}, + {file = "llvmlite-0.42.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb3975787f13eb97629052edb5017f6c170eebc1c14a0433e8089e5db43bcce6"}, + {file = "llvmlite-0.42.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5bece0cdf77f22379f19b1959ccd7aee518afa4afbd3656c6365865f84903f9"}, + {file = "llvmlite-0.42.0-cp311-cp311-win_amd64.whl", hash = "sha256:7e0c4c11c8c2aa9b0701f91b799cb9134a6a6de51444eff5a9087fc7c1384275"}, + {file = "llvmlite-0.42.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:08fa9ab02b0d0179c688a4216b8939138266519aaa0aa94f1195a8542faedb56"}, + {file = "llvmlite-0.42.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b2fce7d355068494d1e42202c7aff25d50c462584233013eb4470c33b995e3ee"}, + {file = "llvmlite-0.42.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebe66a86dc44634b59a3bc860c7b20d26d9aaffcd30364ebe8ba79161a9121f4"}, + {file = "llvmlite-0.42.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d47494552559e00d81bfb836cf1c4d5a5062e54102cc5767d5aa1e77ccd2505c"}, + {file = "llvmlite-0.42.0-cp312-cp312-win_amd64.whl", hash = "sha256:05cb7e9b6ce69165ce4d1b994fbdedca0c62492e537b0cc86141b6e2c78d5888"}, + {file = "llvmlite-0.42.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bdd3888544538a94d7ec99e7c62a0cdd8833609c85f0c23fcb6c5c591aec60ad"}, + {file = "llvmlite-0.42.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d0936c2067a67fb8816c908d5457d63eba3e2b17e515c5fe00e5ee2bace06040"}, + {file = "llvmlite-0.42.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a78ab89f1924fc11482209f6799a7a3fc74ddc80425a7a3e0e8174af0e9e2301"}, + {file = "llvmlite-0.42.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7599b65c7af7abbc978dbf345712c60fd596aa5670496561cc10e8a71cebfb2"}, + {file = "llvmlite-0.42.0-cp39-cp39-win_amd64.whl", hash = "sha256:43d65cc4e206c2e902c1004dd5418417c4efa6c1d04df05c6c5675a27e8ca90e"}, + {file = "llvmlite-0.42.0.tar.gz", hash = "sha256:f92b09243c0cc3f457da8b983f67bd8e1295d0f5b3746c7a1861d7a99403854a"}, +] + +[[package]] +name = "loguru" +version = "0.6.0" +description = "Python logging made (stupidly) simple" +optional = false +python-versions = ">=3.5" +files = [ + {file = "loguru-0.6.0-py3-none-any.whl", hash = "sha256:4e2414d534a2ab57573365b3e6d0234dfb1d84b68b7f3b948e6fb743860a77c3"}, + {file = "loguru-0.6.0.tar.gz", hash = "sha256:066bd06758d0a513e9836fd9c6b5a75bfb3fd36841f4b996bc60b547a309d41c"}, +] + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"] + +[[package]] +name = "markupsafe" +version = 
"2.1.5" +description = "Safely add untrusted strings to HTML/XML markup." +optional = true +python-versions = ">=3.7" +files = [ + {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-win32.whl", hash = "sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl", hash = "sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-win32.whl", hash = "sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-win_amd64.whl", hash = "sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-win32.whl", hash = "sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-win32.whl", hash = "sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-win_amd64.whl", hash = "sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-win32.whl", hash = "sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-win_amd64.whl", hash = "sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-win32.whl", hash = "sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-win_amd64.whl", hash = "sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5"}, + {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, +] + +[[package]] +name = "mkl" +version = "2021.4.0" +description = "Intel® oneAPI Math Kernel Library" +optional = true +python-versions = "*" +files = [ + {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"}, + {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"}, + {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"}, + {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"}, + {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"}, +] + +[package.dependencies] +intel-openmp = "==2021.*" +tbb = "==2021.*" + +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +optional = 
true +python-versions = "*" +files = [ + {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, + {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, +] + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + +[[package]] +name = "multidict" +version = "6.0.5" +description = "multidict implementation" +optional = true +python-versions = ">=3.7" +files = [ + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"}, + {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"}, + {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"}, + {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"}, + {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = 
"sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"}, + {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"}, + {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"}, + {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"}, + {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"}, + {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"}, + {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"}, + {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"}, + {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = 
"sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"}, + {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"}, + {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"}, + {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, +] + +[[package]] +name = "multiprocess" +version = "0.70.15" +description = "better multiprocessing and multithreading in Python" +optional = true +python-versions = ">=3.7" +files = [ + {file = "multiprocess-0.70.15-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aa36c7ed16f508091438687fe9baa393a7a8e206731d321e443745e743a0d4e5"}, + {file = "multiprocess-0.70.15-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:20e024018c46d0d1602024c613007ac948f9754659e3853b0aa705e83f6931d8"}, + {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_i686.whl", hash = "sha256:e576062981c91f0fe8a463c3d52506e598dfc51320a8dd8d78b987dfca91c5db"}, + {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:e73f497e6696a0f5433ada2b3d599ae733b87a6e8b008e387c62ac9127add177"}, + {file = "multiprocess-0.70.15-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:73db2e7b32dcc7f9b0f075c2ffa45c90b6729d3f1805f27e88534c8d321a1be5"}, + {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_i686.whl", hash = "sha256:4271647bd8a49c28ecd6eb56a7fdbd3c212c45529ad5303b40b3c65fc6928e5f"}, + {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:cf981fb998d6ec3208cb14f0cf2e9e80216e834f5d51fd09ebc937c32b960902"}, + {file = "multiprocess-0.70.15-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:18f9f2c7063346d1617bd1684fdcae8d33380ae96b99427260f562e1a1228b67"}, + {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_i686.whl", hash = "sha256:0eac53214d664c49a34695e5824872db4006b1a465edd7459a251809c3773370"}, + {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:1a51dd34096db47fb21fa2b839e615b051d51b97af9a67afbcdaa67186b44883"}, + {file = "multiprocess-0.70.15-py310-none-any.whl", hash = "sha256:7dd58e33235e83cf09d625e55cffd7b0f0eede7ee9223cdd666a87624f60c21a"}, + {file = "multiprocess-0.70.15-py311-none-any.whl", hash = "sha256:134f89053d82c9ed3b73edd3a2531eb791e602d4f4156fc92a79259590bd9670"}, + {file = "multiprocess-0.70.15-py37-none-any.whl", hash = "sha256:f7d4a1629bccb433114c3b4885f69eccc200994323c80f6feee73b0edc9199c5"}, + {file = "multiprocess-0.70.15-py38-none-any.whl", hash = "sha256:bee9afba476c91f9ebee7beeee0601face9eff67d822e893f9a893725fbd6316"}, + {file = "multiprocess-0.70.15-py39-none-any.whl", hash = "sha256:3e0953f5d52b4c76f1c973eaf8214554d146f2be5decb48e928e55c7a2d19338"}, + {file = "multiprocess-0.70.15.tar.gz", hash = "sha256:f20eed3036c0ef477b07a4177cf7c1ba520d9a2677870a4f47fe026f0cd6787e"}, +] + +[package.dependencies] +dill = ">=0.3.7" + +[[package]] +name = "nest-asyncio" +version = "1.6.0" +description = "Patch asyncio to allow nested event loops" +optional = true +python-versions = ">=3.5" +files = [ + {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, + {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, +] + +[[package]] +name = "networkx" +version = 
"3.2.1" +description = "Python package for creating and manipulating graphs and networks" +optional = true +python-versions = ">=3.9" +files = [ + {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, + {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, +] + +[package.extras] +default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] +developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] +doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] +test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] + +[[package]] +name = "numba" +version = "0.59.1" +description = "compiling Python code using LLVM" +optional = true +python-versions = ">=3.9" +files = [ + {file = "numba-0.59.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:97385a7f12212c4f4bc28f648720a92514bee79d7063e40ef66c2d30600fd18e"}, + {file = "numba-0.59.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0b77aecf52040de2a1eb1d7e314497b9e56fba17466c80b457b971a25bb1576d"}, + {file = "numba-0.59.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3476a4f641bfd58f35ead42f4dcaf5f132569c4647c6f1360ccf18ee4cda3990"}, + {file = "numba-0.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:525ef3f820931bdae95ee5379c670d5c97289c6520726bc6937a4a7d4230ba24"}, + {file = "numba-0.59.1-cp310-cp310-win_amd64.whl", hash = "sha256:990e395e44d192a12105eca3083b61307db7da10e093972ca285c85bef0963d6"}, + {file = "numba-0.59.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:43727e7ad20b3ec23ee4fc642f5b61845c71f75dd2825b3c234390c6d8d64051"}, + {file = "numba-0.59.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:411df625372c77959570050e861981e9d196cc1da9aa62c3d6a836b5cc338966"}, + {file = "numba-0.59.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2801003caa263d1e8497fb84829a7ecfb61738a95f62bc05693fcf1733e978e4"}, + {file = "numba-0.59.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dd2842fac03be4e5324ebbbd4d2d0c8c0fc6e0df75c09477dd45b288a0777389"}, + {file = "numba-0.59.1-cp311-cp311-win_amd64.whl", hash = "sha256:0594b3dfb369fada1f8bb2e3045cd6c61a564c62e50cf1f86b4666bc721b3450"}, + {file = "numba-0.59.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1cce206a3b92836cdf26ef39d3a3242fec25e07f020cc4feec4c4a865e340569"}, + {file = "numba-0.59.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8c8b4477763cb1fbd86a3be7050500229417bf60867c93e131fd2626edb02238"}, + {file = "numba-0.59.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d80bce4ef7e65bf895c29e3889ca75a29ee01da80266a01d34815918e365835"}, + {file = "numba-0.59.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f7ad1d217773e89a9845886401eaaab0a156a90aa2f179fdc125261fd1105096"}, + {file = "numba-0.59.1-cp312-cp312-win_amd64.whl", hash = "sha256:5bf68f4d69dd3a9f26a9b23548fa23e3bcb9042e2935257b471d2a8d3c424b7f"}, + {file = "numba-0.59.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4e0318ae729de6e5dbe64c75ead1a95eb01fabfe0e2ebed81ebf0344d32db0ae"}, + {file = "numba-0.59.1-cp39-cp39-macosx_11_0_arm64.whl", 
hash = "sha256:0f68589740a8c38bb7dc1b938b55d1145244c8353078eea23895d4f82c8b9ec1"}, + {file = "numba-0.59.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:649913a3758891c77c32e2d2a3bcbedf4a69f5fea276d11f9119677c45a422e8"}, + {file = "numba-0.59.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9712808e4545270291d76b9a264839ac878c5eb7d8b6e02c970dc0ac29bc8187"}, + {file = "numba-0.59.1-cp39-cp39-win_amd64.whl", hash = "sha256:8d51ccd7008a83105ad6a0082b6a2b70f1142dc7cfd76deb8c5a862367eb8c86"}, + {file = "numba-0.59.1.tar.gz", hash = "sha256:76f69132b96028d2774ed20415e8c528a34e3299a40581bae178f0994a2f370b"}, +] + +[package.dependencies] +llvmlite = "==0.42.*" +numpy = ">=1.22,<1.27" + +[[package]] +name = "numpy" +version = "1.26.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, +] + +[[package]] +name = "nvidia-cublas-cu12" +version = "12.1.3.1" +description = "CUBLAS native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.1.105" +description = "CUDA profiling tools runtime libs." 
+optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.1.105" +description = "NVRTC native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.1.105" +description = "CUDA Runtime native Libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "8.9.2.26" +description = "cuDNN runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.0.2.54" +description = "CUFFT native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.2.106" +description = "CURAND native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.4.5.107" +description = "CUDA solver native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.1.0.106" +description = "CUSPARSE native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = 
"sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, +] + +[package.dependencies] +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.20.5" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"}, + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"}, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.5.40" +description = "Nvidia JIT LTO Library" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, + {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.1.105" +description = "NVIDIA Tools Extension" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, +] + +[[package]] +name = "opentelemetry-api" +version = "1.15.0" +description = "OpenTelemetry Python API" +optional = false +python-versions = ">=3.7" +files = [ + {file = "opentelemetry_api-1.15.0-py3-none-any.whl", hash = "sha256:e6c2d2e42140fd396e96edf75a7ceb11073f4efb4db87565a431cc9d0f93f2e0"}, + {file = "opentelemetry_api-1.15.0.tar.gz", hash = "sha256:79ab791b4aaad27acc3dc3ba01596db5b5aac2ef75c70622c6038051d6c2cded"}, +] + +[package.dependencies] +deprecated = ">=1.2.6" +setuptools = ">=16.0" + +[[package]] +name = "opentelemetry-exporter-otlp" +version = "1.15.0" +description = "OpenTelemetry Collector Exporters" +optional = false +python-versions = ">=3.7" +files = [ + {file = "opentelemetry_exporter_otlp-1.15.0-py3-none-any.whl", hash = "sha256:79f22748b6a54808a0448093dfa189c8490e729f67c134d4c992533d9393b33e"}, + {file = "opentelemetry_exporter_otlp-1.15.0.tar.gz", hash = "sha256:4f7c49751d9720e2e726e13b0bb958ccade4e29122c305d92c033da432c8d2c5"}, +] + +[package.dependencies] +opentelemetry-exporter-otlp-proto-grpc = "1.15.0" +opentelemetry-exporter-otlp-proto-http = "1.15.0" + +[[package]] +name = "opentelemetry-exporter-otlp-proto-grpc" +version = "1.15.0" +description = "OpenTelemetry Collector Protobuf over gRPC Exporter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "opentelemetry_exporter_otlp_proto_grpc-1.15.0-py3-none-any.whl", hash = "sha256:c2a5492ba7d140109968135d641d06ce3c5bd73c50665f787526065d57d7fd1d"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.15.0.tar.gz", hash = "sha256:844f2a4bb9bcda34e4eb6fe36765e5031aacb36dc60ed88c90fc246942ea26e7"}, +] + +[package.dependencies] +backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} +googleapis-common-protos = ">=1.52,<2.0" +grpcio = ">=1.0.0,<2.0.0" +opentelemetry-api = ">=1.12,<2.0" 
+opentelemetry-proto = "1.15.0" +opentelemetry-sdk = ">=1.12,<2.0" + +[package.extras] +test = ["pytest-grpc"] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-http" +version = "1.15.0" +description = "OpenTelemetry Collector Protobuf over HTTP Exporter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "opentelemetry_exporter_otlp_proto_http-1.15.0-py3-none-any.whl", hash = "sha256:3ec2a02196c8a54bf5cbf7fe623a5238625638e83b6047a983bdf96e2bbb74c0"}, + {file = "opentelemetry_exporter_otlp_proto_http-1.15.0.tar.gz", hash = "sha256:11b2c814249a49b22f6cca7a06b05701f561d577b747f3660dfd67b6eb9daf9c"}, +] + +[package.dependencies] +backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""} +googleapis-common-protos = ">=1.52,<2.0" +opentelemetry-api = ">=1.12,<2.0" +opentelemetry-proto = "1.15.0" +opentelemetry-sdk = ">=1.12,<2.0" +requests = ">=2.7,<3.0" + +[package.extras] +test = ["responses (==0.22.0)"] + +[[package]] +name = "opentelemetry-instrumentation" +version = "0.36b0" +description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "opentelemetry_instrumentation-0.36b0-py3-none-any.whl", hash = "sha256:83ba4ae7d5292b5b33e0f851cc5c76d8f91196b9b3527800fc13855c33383ac2"}, + {file = "opentelemetry_instrumentation-0.36b0.tar.gz", hash = "sha256:e3ddac9b3b93408ef26c8ecbf38f717042977e16381bb4cd329a5b4cf16998cf"}, +] + +[package.dependencies] +opentelemetry-api = ">=1.4,<2.0" +setuptools = ">=16.0" +wrapt = ">=1.0.0,<2.0.0" + +[[package]] +name = "opentelemetry-instrumentation-grpc" +version = "0.36b0" +description = "OpenTelemetry gRPC instrumentation" +optional = false +python-versions = ">=3.7" +files = [ + {file = "opentelemetry_instrumentation_grpc-0.36b0-py3-none-any.whl", hash = "sha256:eaa246ed2083c97b13bab2555cb9d170e8433230a31476c4cab8a17fa03380a4"}, + {file = "opentelemetry_instrumentation_grpc-0.36b0.tar.gz", hash = "sha256:dc89447c9eb6ea868970f6c13b4ffdac182cdd5a41dd215a0f5393ca6375be55"}, +] + +[package.dependencies] +opentelemetry-api = ">=1.12,<2.0" +opentelemetry-instrumentation = "0.36b0" +opentelemetry-sdk = ">=1.12,<2.0" +opentelemetry-semantic-conventions = "0.36b0" +wrapt = ">=1.0.0,<2.0.0" + +[package.extras] +instruments = ["grpcio (>=1.27,<2.0)"] +test = ["opentelemetry-instrumentation-grpc[instruments]", "opentelemetry-sdk (>=1.12,<2.0)", "opentelemetry-test-utils (==0.36b0)", "protobuf (>=3.13,<4.0)"] + +[[package]] +name = "opentelemetry-proto" +version = "1.15.0" +description = "OpenTelemetry Python Proto" +optional = false +python-versions = ">=3.7" +files = [ + {file = "opentelemetry_proto-1.15.0-py3-none-any.whl", hash = "sha256:044b6d044b4d10530f250856f933442b8753a17f94ae37c207607f733fb9a844"}, + {file = "opentelemetry_proto-1.15.0.tar.gz", hash = "sha256:9c4008e40ac8cab359daac283fbe7002c5c29c77ea2674ad5626a249e64e0101"}, +] + +[package.dependencies] +protobuf = ">=3.19,<5.0" + +[[package]] +name = "opentelemetry-sdk" +version = "1.15.0" +description = "OpenTelemetry Python SDK" +optional = false +python-versions = ">=3.7" +files = [ + {file = "opentelemetry_sdk-1.15.0-py3-none-any.whl", hash = "sha256:555c533e9837766119bbccc7a80458c9971d853a6f1da683a2246cd5e53b4645"}, + {file = "opentelemetry_sdk-1.15.0.tar.gz", hash = "sha256:98dbffcfeebcbff12c0c974292d6ea603180a145904cf838b1fe4d5c99078425"}, +] + +[package.dependencies] +opentelemetry-api = "1.15.0" +opentelemetry-semantic-conventions = "0.36b0" +setuptools = ">=16.0" 
+typing-extensions = ">=3.7.4" + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.36b0" +description = "OpenTelemetry Semantic Conventions" +optional = false +python-versions = ">=3.7" +files = [ + {file = "opentelemetry_semantic_conventions-0.36b0-py3-none-any.whl", hash = "sha256:adc05635e87b9d3e007c9f530eed487fc3ef2177d02f82f674f28ebf9aff8243"}, + {file = "opentelemetry_semantic_conventions-0.36b0.tar.gz", hash = "sha256:829dc221795467d98b773c04096e29be038d77526dc8d6ac76f546fb6279bf01"}, +] + +[[package]] +name = "outlines" +version = "0.0.36" +description = "Probabilistic Generative Model Programming" +optional = true +python-versions = ">=3.8" +files = [ + {file = "outlines-0.0.36-py3-none-any.whl", hash = "sha256:afa02ca5c449c47731fa06af66d13c2f5ee8b30f8b82b4db90e08215d6f111d1"}, + {file = "outlines-0.0.36.tar.gz", hash = "sha256:3cffb43143548cd78c6061990feb461cffd5479999391b8390471ea839c2d46e"}, +] + +[package.dependencies] +cloudpickle = "*" +diskcache = "*" +interegular = "*" +jinja2 = "*" +joblib = "*" +jsonschema = "*" +lark = "*" +nest-asyncio = "*" +numba = "*" +numpy = "*" +pydantic = ">=2.0" +referencing = "*" +requests = "*" +scipy = "*" +torch = ">=2.1.0" +transformers = "*" + +[package.extras] +serve = ["fastapi", "pydantic (>=2.0)", "ray (==2.9.0)", "uvicorn", "vllm (>=0.3.0)"] +test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "llama-cpp-python", "openai (>=1.0.0)", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"] + +[[package]] +name = "packaging" +version = "24.0" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, + {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, +] + +[[package]] +name = "pandas" +version = "2.2.2" +description = "Powerful data structures for data analysis, time series, and statistics" +optional = true +python-versions = ">=3.9" +files = [ + {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"}, + {file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"}, + {file = 
"pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"}, + {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, + {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, + {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, + {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, + {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, + {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, + {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"}, + {file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"}, + {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, +] +python-dateutil = ">=2.8.2" +pytz = ">=2020.1" +tzdata = ">=2022.7" + +[package.extras] +all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", 
"gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"] +aws = ["s3fs (>=2022.11.0)"] +clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"] +compression = ["zstandard (>=0.19.0)"] +computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"] +consortium-standard = ["dataframe-api-compat (>=0.1.7)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"] +feather = ["pyarrow (>=10.0.1)"] +fss = ["fsspec (>=2022.11.0)"] +gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"] +hdf5 = ["tables (>=3.8.0)"] +html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"] +mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"] +parquet = ["pyarrow (>=10.0.1)"] +performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"] +plot = ["matplotlib (>=3.6.3)"] +postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"] +pyarrow = ["pyarrow (>=10.0.1)"] +spss = ["pyreadstat (>=1.2.0)"] +sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"] +test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] +xml = ["lxml (>=4.9.2)"] + +[[package]] +name = "peft" +version = "0.10.0" +description = "Parameter-Efficient Fine-Tuning (PEFT)" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "peft-0.10.0-py3-none-any.whl", hash = "sha256:d5249c97e818d3e31f92553c73c2953acd0ec12649b8b749afff7152cbc86cbb"}, + {file = "peft-0.10.0.tar.gz", hash = "sha256:36a7628c15f88d37abb26cfc74c22468f9037ee02e9c9b65de943cfe7c672049"}, +] + +[package.dependencies] +accelerate = ">=0.21.0" +huggingface-hub = ">=0.17.0" +numpy = ">=1.17" +packaging = ">=20.0" +psutil = "*" +pyyaml = "*" +safetensors = "*" +torch = ">=1.13.0" +tqdm = "*" +transformers = "*" + +[package.extras] +dev = ["black", "hf-doc-builder", "ruff (>=0.2.1,<0.3.0)"] +docs-specific = ["black", "hf-doc-builder"] +quality = ["black", "hf-doc-builder", "ruff (>=0.2.1,<0.3.0)"] +test = ["black", "datasets", "diffusers (<0.21.0)", "hf-doc-builder", "parameterized", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.2.1,<0.3.0)", "scipy"] + +[[package]] +name = "pillow" +version = "10.3.0" +description = "Python Imaging Library (Fork)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"}, + {file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"}, + {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf"}, + {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599"}, + {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475"}, + {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf"}, + {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3"}, + {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5"}, + {file = "pillow-10.3.0-cp310-cp310-win32.whl", hash = "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2"}, + {file = "pillow-10.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f"}, + {file = "pillow-10.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b"}, + {file = "pillow-10.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795"}, + {file = "pillow-10.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57"}, + {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27"}, + {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994"}, + {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451"}, + {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd"}, + {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad"}, + {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c"}, + {file = "pillow-10.3.0-cp311-cp311-win32.whl", hash = "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09"}, + {file = "pillow-10.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d"}, + {file = "pillow-10.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f"}, + {file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"}, + {file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"}, + {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"}, + {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"}, + {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"}, + {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = 
"sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"}, + {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"}, + {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"}, + {file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"}, + {file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"}, + {file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"}, + {file = "pillow-10.3.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b"}, + {file = "pillow-10.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2"}, + {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa"}, + {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383"}, + {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d"}, + {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd"}, + {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d"}, + {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3"}, + {file = "pillow-10.3.0-cp38-cp38-win32.whl", hash = "sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b"}, + {file = "pillow-10.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999"}, + {file = "pillow-10.3.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936"}, + {file = "pillow-10.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002"}, + {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60"}, + {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375"}, + {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57"}, + {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8"}, + {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9"}, + {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb"}, + {file = "pillow-10.3.0-cp39-cp39-win32.whl", hash = 
"sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572"}, + {file = "pillow-10.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb"}, + {file = "pillow-10.3.0-cp39-cp39-win_arm64.whl", hash = "sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"}, + {file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] +fpx = ["olefile"] +mic = ["olefile"] +tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] +typing = ["typing-extensions"] +xmp = ["defusedxml"] + +[[package]] +name = "pluggy" +version = "1.5.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", 
"pytest-benchmark"] + +[[package]] +name = "prometheus-client" +version = "0.20.0" +description = "Python client for the Prometheus monitoring system." +optional = false +python-versions = ">=3.8" +files = [ + {file = "prometheus_client-0.20.0-py3-none-any.whl", hash = "sha256:cde524a85bce83ca359cc837f28b8c0db5cac7aa653a588fd7e84ba061c329e7"}, + {file = "prometheus_client-0.20.0.tar.gz", hash = "sha256:287629d00b147a32dcb2be0b9df905da599b2d82f80377083ec8463309a4bb89"}, +] + +[package.extras] +twisted = ["twisted"] + +[[package]] +name = "protobuf" +version = "4.25.3" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"}, + {file = "protobuf-4.25.3-cp310-abi3-win_amd64.whl", hash = "sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8"}, + {file = "protobuf-4.25.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c"}, + {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019"}, + {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d"}, + {file = "protobuf-4.25.3-cp38-cp38-win32.whl", hash = "sha256:f4f118245c4a087776e0a8408be33cf09f6c547442c00395fbfb116fac2f8ac2"}, + {file = "protobuf-4.25.3-cp38-cp38-win_amd64.whl", hash = "sha256:c053062984e61144385022e53678fbded7aea14ebb3e0305ae3592fb219ccfa4"}, + {file = "protobuf-4.25.3-cp39-cp39-win32.whl", hash = "sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4"}, + {file = "protobuf-4.25.3-cp39-cp39-win_amd64.whl", hash = "sha256:e3c97a1555fd6388f857770ff8b9703083de6bf1f9274a002a332d65fbb56c8c"}, + {file = "protobuf-4.25.3-py3-none-any.whl", hash = "sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9"}, + {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"}, +] + +[[package]] +name = "psutil" +version = "5.9.8" +description = "Cross-platform lib for process and system monitoring in Python." 
+optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +files = [ + {file = "psutil-5.9.8-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8"}, + {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73"}, + {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7"}, + {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36"}, + {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d"}, + {file = "psutil-5.9.8-cp27-none-win32.whl", hash = "sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e"}, + {file = "psutil-5.9.8-cp27-none-win_amd64.whl", hash = "sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631"}, + {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"}, + {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"}, + {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"}, + {file = "psutil-5.9.8-cp36-cp36m-win32.whl", hash = "sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee"}, + {file = "psutil-5.9.8-cp36-cp36m-win_amd64.whl", hash = "sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2"}, + {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"}, + {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"}, + {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"}, + {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"}, +] + +[package.extras] +test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] + +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +description = "Get CPU info with pure Python" +optional = false +python-versions = "*" +files = [ + {file = "py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690"}, + {file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"}, +] + +[[package]] +name = "pyarrow" +version = "16.1.0" +description = "Python library for Apache Arrow" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"}, + {file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"}, + {file = 
"pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"}, + {file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"}, + {file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"}, + {file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"}, + {file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"}, + {file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"}, + {file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"}, + {file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"}, + {file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"}, + {file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"}, + {file = 
"pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"}, + {file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"}, + {file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"}, + {file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"}, + {file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"}, + {file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"}, +] + +[package.dependencies] +numpy = ">=1.16.6" + +[[package]] +name = "pydantic" +version = "2.7.2" +description = "Data validation using Python type hints" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pydantic-2.7.2-py3-none-any.whl", hash = "sha256:834ab954175f94e6e68258537dc49402c4a5e9d0409b9f1b86b7e934a8372de7"}, + {file = "pydantic-2.7.2.tar.gz", hash = "sha256:71b2945998f9c9b7919a45bde9a50397b289937d215ae141c1d0903ba7149fd7"}, +] + +[package.dependencies] +annotated-types = ">=0.4.0" +pydantic-core = "2.18.3" +typing-extensions = ">=4.6.1" + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.18.3" +description = "Core functionality for Pydantic validation and serialization" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pydantic_core-2.18.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:744697428fcdec6be5670460b578161d1ffe34743a5c15656be7ea82b008197c"}, + {file = "pydantic_core-2.18.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:37b40c05ced1ba4218b14986fe6f283d22e1ae2ff4c8e28881a70fb81fbfcda7"}, + {file = "pydantic_core-2.18.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:544a9a75622357076efb6b311983ff190fbfb3c12fc3a853122b34d3d358126c"}, + {file = "pydantic_core-2.18.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e2e253af04ceaebde8eb201eb3f3e3e7e390f2d275a88300d6a1959d710539e2"}, + {file = "pydantic_core-2.18.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:855ec66589c68aa367d989da5c4755bb74ee92ccad4fdb6af942c3612c067e34"}, + {file = "pydantic_core-2.18.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d3e42bb54e7e9d72c13ce112e02eb1b3b55681ee948d748842171201a03a98a"}, + {file = "pydantic_core-2.18.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6ac9ffccc9d2e69d9fba841441d4259cb668ac180e51b30d3632cd7abca2b9b"}, + {file = "pydantic_core-2.18.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = 
"sha256:c56eca1686539fa0c9bda992e7bd6a37583f20083c37590413381acfc5f192d6"}, + {file = "pydantic_core-2.18.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:17954d784bf8abfc0ec2a633108207ebc4fa2df1a0e4c0c3ccbaa9bb01d2c426"}, + {file = "pydantic_core-2.18.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:98ed737567d8f2ecd54f7c8d4f8572ca7c7921ede93a2e52939416170d357812"}, + {file = "pydantic_core-2.18.3-cp310-none-win32.whl", hash = "sha256:9f9e04afebd3ed8c15d67a564ed0a34b54e52136c6d40d14c5547b238390e779"}, + {file = "pydantic_core-2.18.3-cp310-none-win_amd64.whl", hash = "sha256:45e4ffbae34f7ae30d0047697e724e534a7ec0a82ef9994b7913a412c21462a0"}, + {file = "pydantic_core-2.18.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:b9ebe8231726c49518b16b237b9fe0d7d361dd221302af511a83d4ada01183ab"}, + {file = "pydantic_core-2.18.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b8e20e15d18bf7dbb453be78a2d858f946f5cdf06c5072453dace00ab652e2b2"}, + {file = "pydantic_core-2.18.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0d9ff283cd3459fa0bf9b0256a2b6f01ac1ff9ffb034e24457b9035f75587cb"}, + {file = "pydantic_core-2.18.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2f7ef5f0ebb77ba24c9970da18b771711edc5feaf00c10b18461e0f5f5949231"}, + {file = "pydantic_core-2.18.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73038d66614d2e5cde30435b5afdced2b473b4c77d4ca3a8624dd3e41a9c19be"}, + {file = "pydantic_core-2.18.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6afd5c867a74c4d314c557b5ea9520183fadfbd1df4c2d6e09fd0d990ce412cd"}, + {file = "pydantic_core-2.18.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd7df92f28d351bb9f12470f4c533cf03d1b52ec5a6e5c58c65b183055a60106"}, + {file = "pydantic_core-2.18.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:80aea0ffeb1049336043d07799eace1c9602519fb3192916ff525b0287b2b1e4"}, + {file = "pydantic_core-2.18.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:aaee40f25bba38132e655ffa3d1998a6d576ba7cf81deff8bfa189fb43fd2bbe"}, + {file = "pydantic_core-2.18.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9128089da8f4fe73f7a91973895ebf2502539d627891a14034e45fb9e707e26d"}, + {file = "pydantic_core-2.18.3-cp311-none-win32.whl", hash = "sha256:fec02527e1e03257aa25b1a4dcbe697b40a22f1229f5d026503e8b7ff6d2eda7"}, + {file = "pydantic_core-2.18.3-cp311-none-win_amd64.whl", hash = "sha256:58ff8631dbab6c7c982e6425da8347108449321f61fe427c52ddfadd66642af7"}, + {file = "pydantic_core-2.18.3-cp311-none-win_arm64.whl", hash = "sha256:3fc1c7f67f34c6c2ef9c213e0f2a351797cda98249d9ca56a70ce4ebcaba45f4"}, + {file = "pydantic_core-2.18.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f0928cde2ae416a2d1ebe6dee324709c6f73e93494d8c7aea92df99aab1fc40f"}, + {file = "pydantic_core-2.18.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0bee9bb305a562f8b9271855afb6ce00223f545de3d68560b3c1649c7c5295e9"}, + {file = "pydantic_core-2.18.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e862823be114387257dacbfa7d78547165a85d7add33b446ca4f4fae92c7ff5c"}, + {file = "pydantic_core-2.18.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6a36f78674cbddc165abab0df961b5f96b14461d05feec5e1f78da58808b97e7"}, + {file = "pydantic_core-2.18.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:ba905d184f62e7ddbb7a5a751d8a5c805463511c7b08d1aca4a3e8c11f2e5048"}, + {file = "pydantic_core-2.18.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7fdd362f6a586e681ff86550b2379e532fee63c52def1c666887956748eaa326"}, + {file = "pydantic_core-2.18.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24b214b7ee3bd3b865e963dbed0f8bc5375f49449d70e8d407b567af3222aae4"}, + {file = "pydantic_core-2.18.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:691018785779766127f531674fa82bb368df5b36b461622b12e176c18e119022"}, + {file = "pydantic_core-2.18.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:60e4c625e6f7155d7d0dcac151edf5858102bc61bf959d04469ca6ee4e8381bd"}, + {file = "pydantic_core-2.18.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4e651e47d981c1b701dcc74ab8fec5a60a5b004650416b4abbef13db23bc7be"}, + {file = "pydantic_core-2.18.3-cp312-none-win32.whl", hash = "sha256:ffecbb5edb7f5ffae13599aec33b735e9e4c7676ca1633c60f2c606beb17efc5"}, + {file = "pydantic_core-2.18.3-cp312-none-win_amd64.whl", hash = "sha256:2c8333f6e934733483c7eddffdb094c143b9463d2af7e6bd85ebcb2d4a1b82c6"}, + {file = "pydantic_core-2.18.3-cp312-none-win_arm64.whl", hash = "sha256:7a20dded653e516a4655f4c98e97ccafb13753987434fe7cf044aa25f5b7d417"}, + {file = "pydantic_core-2.18.3-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:eecf63195be644b0396f972c82598cd15693550f0ff236dcf7ab92e2eb6d3522"}, + {file = "pydantic_core-2.18.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2c44efdd3b6125419c28821590d7ec891c9cb0dff33a7a78d9d5c8b6f66b9702"}, + {file = "pydantic_core-2.18.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e59fca51ffbdd1638b3856779342ed69bcecb8484c1d4b8bdb237d0eb5a45e2"}, + {file = "pydantic_core-2.18.3-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:70cf099197d6b98953468461d753563b28e73cf1eade2ffe069675d2657ed1d5"}, + {file = "pydantic_core-2.18.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:63081a49dddc6124754b32a3774331467bfc3d2bd5ff8f10df36a95602560361"}, + {file = "pydantic_core-2.18.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:370059b7883485c9edb9655355ff46d912f4b03b009d929220d9294c7fd9fd60"}, + {file = "pydantic_core-2.18.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a64faeedfd8254f05f5cf6fc755023a7e1606af3959cfc1a9285744cc711044"}, + {file = "pydantic_core-2.18.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:19d2e725de0f90d8671f89e420d36c3dd97639b98145e42fcc0e1f6d492a46dc"}, + {file = "pydantic_core-2.18.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:67bc078025d70ec5aefe6200ef094576c9d86bd36982df1301c758a9fff7d7f4"}, + {file = "pydantic_core-2.18.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:adf952c3f4100e203cbaf8e0c907c835d3e28f9041474e52b651761dc248a3c0"}, + {file = "pydantic_core-2.18.3-cp38-none-win32.whl", hash = "sha256:9a46795b1f3beb167eaee91736d5d17ac3a994bf2215a996aed825a45f897558"}, + {file = "pydantic_core-2.18.3-cp38-none-win_amd64.whl", hash = "sha256:200ad4e3133cb99ed82342a101a5abf3d924722e71cd581cc113fe828f727fbc"}, + {file = "pydantic_core-2.18.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:304378b7bf92206036c8ddd83a2ba7b7d1a5b425acafff637172a3aa72ad7083"}, + {file = "pydantic_core-2.18.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c826870b277143e701c9ccf34ebc33ddb4d072612683a044e7cce2d52f6c3fef"}, + {file 
= "pydantic_core-2.18.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e201935d282707394f3668380e41ccf25b5794d1b131cdd96b07f615a33ca4b1"}, + {file = "pydantic_core-2.18.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5560dda746c44b48bf82b3d191d74fe8efc5686a9ef18e69bdabccbbb9ad9442"}, + {file = "pydantic_core-2.18.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6b32c2a1f8032570842257e4c19288eba9a2bba4712af542327de9a1204faff8"}, + {file = "pydantic_core-2.18.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:929c24e9dea3990bc8bcd27c5f2d3916c0c86f5511d2caa69e0d5290115344a9"}, + {file = "pydantic_core-2.18.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1a8376fef60790152564b0eab376b3e23dd6e54f29d84aad46f7b264ecca943"}, + {file = "pydantic_core-2.18.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dccf3ef1400390ddd1fb55bf0632209d39140552d068ee5ac45553b556780e06"}, + {file = "pydantic_core-2.18.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:41dbdcb0c7252b58fa931fec47937edb422c9cb22528f41cb8963665c372caf6"}, + {file = "pydantic_core-2.18.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:666e45cf071669fde468886654742fa10b0e74cd0fa0430a46ba6056b24fb0af"}, + {file = "pydantic_core-2.18.3-cp39-none-win32.whl", hash = "sha256:f9c08cabff68704a1b4667d33f534d544b8a07b8e5d039c37067fceb18789e78"}, + {file = "pydantic_core-2.18.3-cp39-none-win_amd64.whl", hash = "sha256:4afa5f5973e8572b5c0dcb4e2d4fda7890e7cd63329bd5cc3263a25c92ef0026"}, + {file = "pydantic_core-2.18.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:77319771a026f7c7d29c6ebc623de889e9563b7087911b46fd06c044a12aa5e9"}, + {file = "pydantic_core-2.18.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:df11fa992e9f576473038510d66dd305bcd51d7dd508c163a8c8fe148454e059"}, + {file = "pydantic_core-2.18.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d531076bdfb65af593326ffd567e6ab3da145020dafb9187a1d131064a55f97c"}, + {file = "pydantic_core-2.18.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d33ce258e4e6e6038f2b9e8b8a631d17d017567db43483314993b3ca345dcbbb"}, + {file = "pydantic_core-2.18.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1f9cd7f5635b719939019be9bda47ecb56e165e51dd26c9a217a433e3d0d59a9"}, + {file = "pydantic_core-2.18.3-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cd4a032bb65cc132cae1fe3e52877daecc2097965cd3914e44fbd12b00dae7c5"}, + {file = "pydantic_core-2.18.3-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:82f2718430098bcdf60402136c845e4126a189959d103900ebabb6774a5d9fdb"}, + {file = "pydantic_core-2.18.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:c0037a92cf0c580ed14e10953cdd26528e8796307bb8bb312dc65f71547df04d"}, + {file = "pydantic_core-2.18.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b95a0972fac2b1ff3c94629fc9081b16371dad870959f1408cc33b2f78ad347a"}, + {file = "pydantic_core-2.18.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:a62e437d687cc148381bdd5f51e3e81f5b20a735c55f690c5be94e05da2b0d5c"}, + {file = "pydantic_core-2.18.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b367a73a414bbb08507da102dc2cde0fa7afe57d09b3240ce82a16d608a7679c"}, + {file = "pydantic_core-2.18.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:0ecce4b2360aa3f008da3327d652e74a0e743908eac306198b47e1c58b03dd2b"}, + {file = "pydantic_core-2.18.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bd4435b8d83f0c9561a2a9585b1de78f1abb17cb0cef5f39bf6a4b47d19bafe3"}, + {file = "pydantic_core-2.18.3-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:616221a6d473c5b9aa83fa8982745441f6a4a62a66436be9445c65f241b86c94"}, + {file = "pydantic_core-2.18.3-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7e6382ce89a92bc1d0c0c5edd51e931432202b9080dc921d8d003e616402efd1"}, + {file = "pydantic_core-2.18.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:ff58f379345603d940e461eae474b6bbb6dab66ed9a851ecd3cb3709bf4dcf6a"}, + {file = "pydantic_core-2.18.3.tar.gz", hash = "sha256:432e999088d85c8f36b9a3f769a8e2b57aabd817bbb729a90d1fe7f18f6f1f39"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + +[[package]] +name = "pytest" +version = "7.4.4" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, + {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +description = "Extensions to the standard Python datetime module" +optional = true +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, +] + +[package.dependencies] +six = ">=1.5" + +[[package]] +name = "pytz" +version = "2024.1" +description = "World timezone definitions, modern and historical" +optional = true +python-versions = "*" +files = [ + {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, + {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, +] + +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = 
"PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = 
"sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + +[[package]] +name = "referencing" +version = "0.35.1" +description = "JSON 
Referencing + Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "referencing-0.35.1-py3-none-any.whl", hash = "sha256:eda6d3234d62814d1c64e305c1331c9a3a6132da475ab6382eaa997b21ee75de"}, + {file = "referencing-0.35.1.tar.gz", hash = "sha256:25b42124a6c8b632a425174f24087783efb348a6f1e0008e63cd4466fedf703c"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +rpds-py = ">=0.7.0" + +[[package]] +name = "regex" +version = "2024.5.15" +description = "Alternative regular expression module, to replace re." +optional = false +python-versions = ">=3.8" +files = [ + {file = "regex-2024.5.15-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a81e3cfbae20378d75185171587cbf756015ccb14840702944f014e0d93ea09f"}, + {file = "regex-2024.5.15-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7b59138b219ffa8979013be7bc85bb60c6f7b7575df3d56dc1e403a438c7a3f6"}, + {file = "regex-2024.5.15-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a0bd000c6e266927cb7a1bc39d55be95c4b4f65c5be53e659537537e019232b1"}, + {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5eaa7ddaf517aa095fa8da0b5015c44d03da83f5bd49c87961e3c997daed0de7"}, + {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba68168daedb2c0bab7fd7e00ced5ba90aebf91024dea3c88ad5063c2a562cca"}, + {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6e8d717bca3a6e2064fc3a08df5cbe366369f4b052dcd21b7416e6d71620dca1"}, + {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1337b7dbef9b2f71121cdbf1e97e40de33ff114801263b275aafd75303bd62b5"}, + {file = "regex-2024.5.15-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f9ebd0a36102fcad2f03696e8af4ae682793a5d30b46c647eaf280d6cfb32796"}, + {file = "regex-2024.5.15-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9efa1a32ad3a3ea112224897cdaeb6aa00381627f567179c0314f7b65d354c62"}, + {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:1595f2d10dff3d805e054ebdc41c124753631b6a471b976963c7b28543cf13b0"}, + {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b802512f3e1f480f41ab5f2cfc0e2f761f08a1f41092d6718868082fc0d27143"}, + {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a0981022dccabca811e8171f913de05720590c915b033b7e601f35ce4ea7019f"}, + {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:19068a6a79cf99a19ccefa44610491e9ca02c2be3305c7760d3831d38a467a6f"}, + {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1b5269484f6126eee5e687785e83c6b60aad7663dafe842b34691157e5083e53"}, + {file = "regex-2024.5.15-cp310-cp310-win32.whl", hash = "sha256:ada150c5adfa8fbcbf321c30c751dc67d2f12f15bd183ffe4ec7cde351d945b3"}, + {file = "regex-2024.5.15-cp310-cp310-win_amd64.whl", hash = "sha256:ac394ff680fc46b97487941f5e6ae49a9f30ea41c6c6804832063f14b2a5a145"}, + {file = "regex-2024.5.15-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f5b1dff3ad008dccf18e652283f5e5339d70bf8ba7c98bf848ac33db10f7bc7a"}, + {file = "regex-2024.5.15-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c6a2b494a76983df8e3d3feea9b9ffdd558b247e60b92f877f93a1ff43d26656"}, + {file = "regex-2024.5.15-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:a32b96f15c8ab2e7d27655969a23895eb799de3665fa94349f3b2fbfd547236f"}, + {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10002e86e6068d9e1c91eae8295ef690f02f913c57db120b58fdd35a6bb1af35"}, + {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ec54d5afa89c19c6dd8541a133be51ee1017a38b412b1321ccb8d6ddbeb4cf7d"}, + {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:10e4ce0dca9ae7a66e6089bb29355d4432caed736acae36fef0fdd7879f0b0cb"}, + {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e507ff1e74373c4d3038195fdd2af30d297b4f0950eeda6f515ae3d84a1770f"}, + {file = "regex-2024.5.15-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1f059a4d795e646e1c37665b9d06062c62d0e8cc3c511fe01315973a6542e40"}, + {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0721931ad5fe0dda45d07f9820b90b2148ccdd8e45bb9e9b42a146cb4f695649"}, + {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:833616ddc75ad595dee848ad984d067f2f31be645d603e4d158bba656bbf516c"}, + {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:287eb7f54fc81546346207c533ad3c2c51a8d61075127d7f6d79aaf96cdee890"}, + {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:19dfb1c504781a136a80ecd1fff9f16dddf5bb43cec6871778c8a907a085bb3d"}, + {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:119af6e56dce35e8dfb5222573b50c89e5508d94d55713c75126b753f834de68"}, + {file = "regex-2024.5.15-cp311-cp311-win32.whl", hash = "sha256:1c1c174d6ec38d6c8a7504087358ce9213d4332f6293a94fbf5249992ba54efa"}, + {file = "regex-2024.5.15-cp311-cp311-win_amd64.whl", hash = "sha256:9e717956dcfd656f5055cc70996ee2cc82ac5149517fc8e1b60261b907740201"}, + {file = "regex-2024.5.15-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:632b01153e5248c134007209b5c6348a544ce96c46005d8456de1d552455b014"}, + {file = "regex-2024.5.15-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e64198f6b856d48192bf921421fdd8ad8eb35e179086e99e99f711957ffedd6e"}, + {file = "regex-2024.5.15-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68811ab14087b2f6e0fc0c2bae9ad689ea3584cad6917fc57be6a48bbd012c49"}, + {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8ec0c2fea1e886a19c3bee0cd19d862b3aa75dcdfb42ebe8ed30708df64687a"}, + {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0c0c0003c10f54a591d220997dd27d953cd9ccc1a7294b40a4be5312be8797b"}, + {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2431b9e263af1953c55abbd3e2efca67ca80a3de8a0437cb58e2421f8184717a"}, + {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a605586358893b483976cffc1723fb0f83e526e8f14c6e6614e75919d9862cf"}, + {file = "regex-2024.5.15-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:391d7f7f1e409d192dba8bcd42d3e4cf9e598f3979cdaed6ab11288da88cb9f2"}, + {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9ff11639a8d98969c863d4617595eb5425fd12f7c5ef6621a4b74b71ed8726d5"}, + {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_i686.whl", hash = 
"sha256:4eee78a04e6c67e8391edd4dad3279828dd66ac4b79570ec998e2155d2e59fd5"}, + {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:8fe45aa3f4aa57faabbc9cb46a93363edd6197cbc43523daea044e9ff2fea83e"}, + {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:d0a3d8d6acf0c78a1fff0e210d224b821081330b8524e3e2bc5a68ef6ab5803d"}, + {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c486b4106066d502495b3025a0a7251bf37ea9540433940a23419461ab9f2a80"}, + {file = "regex-2024.5.15-cp312-cp312-win32.whl", hash = "sha256:c49e15eac7c149f3670b3e27f1f28a2c1ddeccd3a2812cba953e01be2ab9b5fe"}, + {file = "regex-2024.5.15-cp312-cp312-win_amd64.whl", hash = "sha256:673b5a6da4557b975c6c90198588181029c60793835ce02f497ea817ff647cb2"}, + {file = "regex-2024.5.15-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:87e2a9c29e672fc65523fb47a90d429b70ef72b901b4e4b1bd42387caf0d6835"}, + {file = "regex-2024.5.15-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c3bea0ba8b73b71b37ac833a7f3fd53825924165da6a924aec78c13032f20850"}, + {file = "regex-2024.5.15-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bfc4f82cabe54f1e7f206fd3d30fda143f84a63fe7d64a81558d6e5f2e5aaba9"}, + {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5bb9425fe881d578aeca0b2b4b3d314ec88738706f66f219c194d67179337cb"}, + {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64c65783e96e563103d641760664125e91bd85d8e49566ee560ded4da0d3e704"}, + {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cf2430df4148b08fb4324b848672514b1385ae3807651f3567871f130a728cc3"}, + {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5397de3219a8b08ae9540c48f602996aa6b0b65d5a61683e233af8605c42b0f2"}, + {file = "regex-2024.5.15-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:455705d34b4154a80ead722f4f185b04c4237e8e8e33f265cd0798d0e44825fa"}, + {file = "regex-2024.5.15-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b2b6f1b3bb6f640c1a92be3bbfbcb18657b125b99ecf141fb3310b5282c7d4ed"}, + {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3ad070b823ca5890cab606c940522d05d3d22395d432f4aaaf9d5b1653e47ced"}, + {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:5b5467acbfc153847d5adb21e21e29847bcb5870e65c94c9206d20eb4e99a384"}, + {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:e6662686aeb633ad65be2a42b4cb00178b3fbf7b91878f9446075c404ada552f"}, + {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:2b4c884767504c0e2401babe8b5b7aea9148680d2e157fa28f01529d1f7fcf67"}, + {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:3cd7874d57f13bf70078f1ff02b8b0aa48d5b9ed25fc48547516c6aba36f5741"}, + {file = "regex-2024.5.15-cp38-cp38-win32.whl", hash = "sha256:e4682f5ba31f475d58884045c1a97a860a007d44938c4c0895f41d64481edbc9"}, + {file = "regex-2024.5.15-cp38-cp38-win_amd64.whl", hash = "sha256:d99ceffa25ac45d150e30bd9ed14ec6039f2aad0ffa6bb87a5936f5782fc1569"}, + {file = "regex-2024.5.15-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:13cdaf31bed30a1e1c2453ef6015aa0983e1366fad2667657dbcac7b02f67133"}, + {file = "regex-2024.5.15-cp39-cp39-macosx_10_9_x86_64.whl", hash = 
"sha256:cac27dcaa821ca271855a32188aa61d12decb6fe45ffe3e722401fe61e323cd1"}, + {file = "regex-2024.5.15-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7dbe2467273b875ea2de38ded4eba86cbcbc9a1a6d0aa11dcf7bd2e67859c435"}, + {file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64f18a9a3513a99c4bef0e3efd4c4a5b11228b48aa80743be822b71e132ae4f5"}, + {file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d347a741ea871c2e278fde6c48f85136c96b8659b632fb57a7d1ce1872547600"}, + {file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1878b8301ed011704aea4c806a3cadbd76f84dece1ec09cc9e4dc934cfa5d4da"}, + {file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4babf07ad476aaf7830d77000874d7611704a7fcf68c9c2ad151f5d94ae4bfc4"}, + {file = "regex-2024.5.15-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35cb514e137cb3488bce23352af3e12fb0dbedd1ee6e60da053c69fb1b29cc6c"}, + {file = "regex-2024.5.15-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cdd09d47c0b2efee9378679f8510ee6955d329424c659ab3c5e3a6edea696294"}, + {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:72d7a99cd6b8f958e85fc6ca5b37c4303294954eac1376535b03c2a43eb72629"}, + {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:a094801d379ab20c2135529948cb84d417a2169b9bdceda2a36f5f10977ebc16"}, + {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c0c18345010870e58238790a6779a1219b4d97bd2e77e1140e8ee5d14df071aa"}, + {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:16093f563098448ff6b1fa68170e4acbef94e6b6a4e25e10eae8598bb1694b5d"}, + {file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:e38a7d4e8f633a33b4c7350fbd8bad3b70bf81439ac67ac38916c4a86b465456"}, + {file = "regex-2024.5.15-cp39-cp39-win32.whl", hash = "sha256:71a455a3c584a88f654b64feccc1e25876066c4f5ef26cd6dd711308aa538694"}, + {file = "regex-2024.5.15-cp39-cp39-win_amd64.whl", hash = "sha256:cab12877a9bdafde5500206d1020a584355a97884dfd388af3699e9137bf7388"}, + {file = "regex-2024.5.15.tar.gz", hash = "sha256:d3ee02d9e5f482cc8309134a91eeaacbdd2261ba111b0fef3748eeb4913e6a2c"}, +] + +[[package]] +name = "requests" +version = "2.32.3" +description = "Python HTTP for Humans." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, +] + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<3" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] + +[[package]] +name = "rpds-py" +version = "0.18.1" +description = "Python bindings to Rust's persistent data structures (rpds)" +optional = true +python-versions = ">=3.8" +files = [ + {file = "rpds_py-0.18.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:d31dea506d718693b6b2cffc0648a8929bdc51c70a311b2770f09611caa10d53"}, + {file = "rpds_py-0.18.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:732672fbc449bab754e0b15356c077cc31566df874964d4801ab14f71951ea80"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a98a1f0552b5f227a3d6422dbd61bc6f30db170939bd87ed14f3c339aa6c7c9"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7f1944ce16401aad1e3f7d312247b3d5de7981f634dc9dfe90da72b87d37887d"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38e14fb4e370885c4ecd734f093a2225ee52dc384b86fa55fe3f74638b2cfb09"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08d74b184f9ab6289b87b19fe6a6d1a97fbfea84b8a3e745e87a5de3029bf944"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d70129cef4a8d979caa37e7fe957202e7eee8ea02c5e16455bc9808a59c6b2f0"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ce0bb20e3a11bd04461324a6a798af34d503f8d6f1aa3d2aa8901ceaf039176d"}, + {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:81c5196a790032e0fc2464c0b4ab95f8610f96f1f2fa3d4deacce6a79852da60"}, + {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:f3027be483868c99b4985fda802a57a67fdf30c5d9a50338d9db646d590198da"}, + {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d44607f98caa2961bab4fa3c4309724b185b464cdc3ba6f3d7340bac3ec97cc1"}, + {file = "rpds_py-0.18.1-cp310-none-win32.whl", hash = "sha256:c273e795e7a0f1fddd46e1e3cb8be15634c29ae8ff31c196debb620e1edb9333"}, + {file = "rpds_py-0.18.1-cp310-none-win_amd64.whl", hash = "sha256:8352f48d511de5f973e4f2f9412736d7dea76c69faa6d36bcf885b50c758ab9a"}, + {file = "rpds_py-0.18.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6b5ff7e1d63a8281654b5e2896d7f08799378e594f09cf3674e832ecaf396ce8"}, + {file = "rpds_py-0.18.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8927638a4d4137a289e41d0fd631551e89fa346d6dbcfc31ad627557d03ceb6d"}, + {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:154bf5c93d79558b44e5b50cc354aa0459e518e83677791e6adb0b039b7aa6a7"}, + {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:07f2139741e5deb2c5154a7b9629bc5aa48c766b643c1a6750d16f865a82c5fc"}, + {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c7672e9fba7425f79019db9945b16e308ed8bc89348c23d955c8c0540da0a07"}, + {file 
= "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:489bdfe1abd0406eba6b3bb4fdc87c7fa40f1031de073d0cfb744634cc8fa261"}, + {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c20f05e8e3d4fc76875fc9cb8cf24b90a63f5a1b4c5b9273f0e8225e169b100"}, + {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:967342e045564cef76dfcf1edb700b1e20838d83b1aa02ab313e6a497cf923b8"}, + {file = "rpds_py-0.18.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2cc7c1a47f3a63282ab0f422d90ddac4aa3034e39fc66a559ab93041e6505da7"}, + {file = "rpds_py-0.18.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f7afbfee1157e0f9376c00bb232e80a60e59ed716e3211a80cb8506550671e6e"}, + {file = "rpds_py-0.18.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9e6934d70dc50f9f8ea47081ceafdec09245fd9f6032669c3b45705dea096b88"}, + {file = "rpds_py-0.18.1-cp311-none-win32.whl", hash = "sha256:c69882964516dc143083d3795cb508e806b09fc3800fd0d4cddc1df6c36e76bb"}, + {file = "rpds_py-0.18.1-cp311-none-win_amd64.whl", hash = "sha256:70a838f7754483bcdc830444952fd89645569e7452e3226de4a613a4c1793fb2"}, + {file = "rpds_py-0.18.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:3dd3cd86e1db5aadd334e011eba4e29d37a104b403e8ca24dcd6703c68ca55b3"}, + {file = "rpds_py-0.18.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:05f3d615099bd9b13ecf2fc9cf2d839ad3f20239c678f461c753e93755d629ee"}, + {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35b2b771b13eee8729a5049c976197ff58a27a3829c018a04341bcf1ae409b2b"}, + {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ee17cd26b97d537af8f33635ef38be873073d516fd425e80559f4585a7b90c43"}, + {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b646bf655b135ccf4522ed43d6902af37d3f5dbcf0da66c769a2b3938b9d8184"}, + {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19ba472b9606c36716062c023afa2484d1e4220548751bda14f725a7de17b4f6"}, + {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e30ac5e329098903262dc5bdd7e2086e0256aa762cc8b744f9e7bf2a427d3f8"}, + {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d58ad6317d188c43750cb76e9deacf6051d0f884d87dc6518e0280438648a9ac"}, + {file = "rpds_py-0.18.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e1735502458621921cee039c47318cb90b51d532c2766593be6207eec53e5c4c"}, + {file = "rpds_py-0.18.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:f5bab211605d91db0e2995a17b5c6ee5edec1270e46223e513eaa20da20076ac"}, + {file = "rpds_py-0.18.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2fc24a329a717f9e2448f8cd1f960f9dac4e45b6224d60734edeb67499bab03a"}, + {file = "rpds_py-0.18.1-cp312-none-win32.whl", hash = "sha256:1805d5901779662d599d0e2e4159d8a82c0b05faa86ef9222bf974572286b2b6"}, + {file = "rpds_py-0.18.1-cp312-none-win_amd64.whl", hash = "sha256:720edcb916df872d80f80a1cc5ea9058300b97721efda8651efcd938a9c70a72"}, + {file = "rpds_py-0.18.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:c827576e2fa017a081346dce87d532a5310241648eb3700af9a571a6e9fc7e74"}, + {file = "rpds_py-0.18.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:aa3679e751408d75a0b4d8d26d6647b6d9326f5e35c00a7ccd82b78ef64f65f8"}, + {file = 
"rpds_py-0.18.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0abeee75434e2ee2d142d650d1e54ac1f8b01e6e6abdde8ffd6eeac6e9c38e20"}, + {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed402d6153c5d519a0faf1bb69898e97fb31613b49da27a84a13935ea9164dfc"}, + {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:338dee44b0cef8b70fd2ef54b4e09bb1b97fc6c3a58fea5db6cc083fd9fc2724"}, + {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7750569d9526199c5b97e5a9f8d96a13300950d910cf04a861d96f4273d5b104"}, + {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:607345bd5912aacc0c5a63d45a1f73fef29e697884f7e861094e443187c02be5"}, + {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:207c82978115baa1fd8d706d720b4a4d2b0913df1c78c85ba73fe6c5804505f0"}, + {file = "rpds_py-0.18.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6d1e42d2735d437e7e80bab4d78eb2e459af48c0a46e686ea35f690b93db792d"}, + {file = "rpds_py-0.18.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:5463c47c08630007dc0fe99fb480ea4f34a89712410592380425a9b4e1611d8e"}, + {file = "rpds_py-0.18.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:06d218939e1bf2ca50e6b0ec700ffe755e5216a8230ab3e87c059ebb4ea06afc"}, + {file = "rpds_py-0.18.1-cp38-none-win32.whl", hash = "sha256:312fe69b4fe1ffbe76520a7676b1e5ac06ddf7826d764cc10265c3b53f96dbe9"}, + {file = "rpds_py-0.18.1-cp38-none-win_amd64.whl", hash = "sha256:9437ca26784120a279f3137ee080b0e717012c42921eb07861b412340f85bae2"}, + {file = "rpds_py-0.18.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:19e515b78c3fc1039dd7da0a33c28c3154458f947f4dc198d3c72db2b6b5dc93"}, + {file = "rpds_py-0.18.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a7b28c5b066bca9a4eb4e2f2663012debe680f097979d880657f00e1c30875a0"}, + {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:673fdbbf668dd958eff750e500495ef3f611e2ecc209464f661bc82e9838991e"}, + {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d960de62227635d2e61068f42a6cb6aae91a7fe00fca0e3aeed17667c8a34611"}, + {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:352a88dc7892f1da66b6027af06a2e7e5d53fe05924cc2cfc56495b586a10b72"}, + {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4e0ee01ad8260184db21468a6e1c37afa0529acc12c3a697ee498d3c2c4dcaf3"}, + {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4c39ad2f512b4041343ea3c7894339e4ca7839ac38ca83d68a832fc8b3748ab"}, + {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aaa71ee43a703c321906813bb252f69524f02aa05bf4eec85f0c41d5d62d0f4c"}, + {file = "rpds_py-0.18.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:6cd8098517c64a85e790657e7b1e509b9fe07487fd358e19431cb120f7d96338"}, + {file = "rpds_py-0.18.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4adec039b8e2928983f885c53b7cc4cda8965b62b6596501a0308d2703f8af1b"}, + {file = "rpds_py-0.18.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:32b7daaa3e9389db3695964ce8e566e3413b0c43e3394c05e4b243a4cd7bef26"}, + {file = "rpds_py-0.18.1-cp39-none-win32.whl", hash = "sha256:2625f03b105328729f9450c8badda34d5243231eef6535f80064d57035738360"}, + 
{file = "rpds_py-0.18.1-cp39-none-win_amd64.whl", hash = "sha256:bf18932d0003c8c4d51a39f244231986ab23ee057d235a12b2684ea26a353590"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:cbfbea39ba64f5e53ae2915de36f130588bba71245b418060ec3330ebf85678e"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:a3d456ff2a6a4d2adcdf3c1c960a36f4fd2fec6e3b4902a42a384d17cf4e7a65"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7700936ef9d006b7ef605dc53aa364da2de5a3aa65516a1f3ce73bf82ecfc7ae"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:51584acc5916212e1bf45edd17f3a6b05fe0cbb40482d25e619f824dccb679de"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:942695a206a58d2575033ff1e42b12b2aece98d6003c6bc739fbf33d1773b12f"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b906b5f58892813e5ba5c6056d6a5ad08f358ba49f046d910ad992196ea61397"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6f8e3fecca256fefc91bb6765a693d96692459d7d4c644660a9fff32e517843"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7732770412bab81c5a9f6d20aeb60ae943a9b36dcd990d876a773526468e7163"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:bd1105b50ede37461c1d51b9698c4f4be6e13e69a908ab7751e3807985fc0346"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:618916f5535784960f3ecf8111581f4ad31d347c3de66d02e728de460a46303c"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:17c6d2155e2423f7e79e3bb18151c686d40db42d8645e7977442170c360194d4"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6c4c4c3f878df21faf5fac86eda32671c27889e13570645a9eea0a1abdd50922"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:fab6ce90574645a0d6c58890e9bcaac8d94dff54fb51c69e5522a7358b80ab64"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:531796fb842b53f2695e94dc338929e9f9dbf473b64710c28af5a160b2a8927d"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:740884bc62a5e2bbb31e584f5d23b32320fd75d79f916f15a788d527a5e83644"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:998125738de0158f088aef3cb264a34251908dd2e5d9966774fdab7402edfab7"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e2be6e9dd4111d5b31ba3b74d17da54a8319d8168890fbaea4b9e5c3de630ae5"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0cee71bc618cd93716f3c1bf56653740d2d13ddbd47673efa8bf41435a60daa"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2c3caec4ec5cd1d18e5dd6ae5194d24ed12785212a90b37f5f7f06b8bedd7139"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:27bba383e8c5231cd559affe169ca0b96ec78d39909ffd817f28b166d7ddd4d8"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = 
"sha256:a888e8bdb45916234b99da2d859566f1e8a1d2275a801bb8e4a9644e3c7e7909"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:6031b25fb1b06327b43d841f33842b383beba399884f8228a6bb3df3088485ff"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:48c2faaa8adfacefcbfdb5f2e2e7bdad081e5ace8d182e5f4ade971f128e6bb3"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:d85164315bd68c0806768dc6bb0429c6f95c354f87485ee3593c4f6b14def2bd"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6afd80f6c79893cfc0574956f78a0add8c76e3696f2d6a15bca2c66c415cf2d4"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa242ac1ff583e4ec7771141606aafc92b361cd90a05c30d93e343a0c2d82a89"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21be4770ff4e08698e1e8e0bce06edb6ea0626e7c8f560bc08222880aca6a6f"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c45a639e93a0c5d4b788b2613bd637468edd62f8f95ebc6fcc303d58ab3f0a8"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:910e71711d1055b2768181efa0a17537b2622afeb0424116619817007f8a2b10"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b9bb1f182a97880f6078283b3505a707057c42bf55d8fca604f70dedfdc0772a"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:1d54f74f40b1f7aaa595a02ff42ef38ca654b1469bef7d52867da474243cc633"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:8d2e182c9ee01135e11e9676e9a62dfad791a7a467738f06726872374a83db49"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:636a15acc588f70fda1661234761f9ed9ad79ebed3f2125d44be0862708b666e"}, + {file = "rpds_py-0.18.1.tar.gz", hash = "sha256:dc48b479d540770c811fbd1eb9ba2bb66951863e448efec2e2c102625328e92f"}, +] + +[[package]] +name = "safetensors" +version = "0.4.3" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "safetensors-0.4.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:dcf5705cab159ce0130cd56057f5f3425023c407e170bca60b4868048bae64fd"}, + {file = "safetensors-0.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bb4f8c5d0358a31e9a08daeebb68f5e161cdd4018855426d3f0c23bb51087055"}, + {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70a5319ef409e7f88686a46607cbc3c428271069d8b770076feaf913664a07ac"}, + {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fb9c65bd82f9ef3ce4970dc19ee86be5f6f93d032159acf35e663c6bea02b237"}, + {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:edb5698a7bc282089f64c96c477846950358a46ede85a1c040e0230344fdde10"}, + {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:efcc860be094b8d19ac61b452ec635c7acb9afa77beb218b1d7784c6d41fe8ad"}, + {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d88b33980222085dd6001ae2cad87c6068e0991d4f5ccf44975d216db3b57376"}, + {file = "safetensors-0.4.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = 
"sha256:5fc6775529fb9f0ce2266edd3e5d3f10aab068e49f765e11f6f2a63b5367021d"}, + {file = "safetensors-0.4.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9c6ad011c1b4e3acff058d6b090f1da8e55a332fbf84695cf3100c649cc452d1"}, + {file = "safetensors-0.4.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8c496c5401c1b9c46d41a7688e8ff5b0310a3b9bae31ce0f0ae870e1ea2b8caf"}, + {file = "safetensors-0.4.3-cp310-none-win32.whl", hash = "sha256:38e2a8666178224a51cca61d3cb4c88704f696eac8f72a49a598a93bbd8a4af9"}, + {file = "safetensors-0.4.3-cp310-none-win_amd64.whl", hash = "sha256:393e6e391467d1b2b829c77e47d726f3b9b93630e6a045b1d1fca67dc78bf632"}, + {file = "safetensors-0.4.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:22f3b5d65e440cec0de8edaa672efa888030802e11c09b3d6203bff60ebff05a"}, + {file = "safetensors-0.4.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c4fa560ebd4522adddb71dcd25d09bf211b5634003f015a4b815b7647d62ebe"}, + {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e9afd5358719f1b2cf425fad638fc3c887997d6782da317096877e5b15b2ce93"}, + {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d8c5093206ef4b198600ae484230402af6713dab1bd5b8e231905d754022bec7"}, + {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0b2104df1579d6ba9052c0ae0e3137c9698b2d85b0645507e6fd1813b70931a"}, + {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8cf18888606dad030455d18f6c381720e57fc6a4170ee1966adb7ebc98d4d6a3"}, + {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0bf4f9d6323d9f86eef5567eabd88f070691cf031d4c0df27a40d3b4aaee755b"}, + {file = "safetensors-0.4.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:585c9ae13a205807b63bef8a37994f30c917ff800ab8a1ca9c9b5d73024f97ee"}, + {file = "safetensors-0.4.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:faefeb3b81bdfb4e5a55b9bbdf3d8d8753f65506e1d67d03f5c851a6c87150e9"}, + {file = "safetensors-0.4.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:befdf0167ad626f22f6aac6163477fcefa342224a22f11fdd05abb3995c1783c"}, + {file = "safetensors-0.4.3-cp311-none-win32.whl", hash = "sha256:a7cef55929dcbef24af3eb40bedec35d82c3c2fa46338bb13ecf3c5720af8a61"}, + {file = "safetensors-0.4.3-cp311-none-win_amd64.whl", hash = "sha256:840b7ac0eff5633e1d053cc9db12fdf56b566e9403b4950b2dc85393d9b88d67"}, + {file = "safetensors-0.4.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:22d21760dc6ebae42e9c058d75aa9907d9f35e38f896e3c69ba0e7b213033856"}, + {file = "safetensors-0.4.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d22c1a10dff3f64d0d68abb8298a3fd88ccff79f408a3e15b3e7f637ef5c980"}, + {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1648568667f820b8c48317c7006221dc40aced1869908c187f493838a1362bc"}, + {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:446e9fe52c051aeab12aac63d1017e0f68a02a92a027b901c4f8e931b24e5397"}, + {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fef5d70683643618244a4f5221053567ca3e77c2531e42ad48ae05fae909f542"}, + {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a1f4430cc0c9d6afa01214a4b3919d0a029637df8e09675ceef1ca3f0dfa0df"}, + 
{file = "safetensors-0.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d603846a8585b9432a0fd415db1d4c57c0f860eb4aea21f92559ff9902bae4d"}, + {file = "safetensors-0.4.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a844cdb5d7cbc22f5f16c7e2a0271170750763c4db08381b7f696dbd2c78a361"}, + {file = "safetensors-0.4.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:88887f69f7a00cf02b954cdc3034ffb383b2303bc0ab481d4716e2da51ddc10e"}, + {file = "safetensors-0.4.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ee463219d9ec6c2be1d331ab13a8e0cd50d2f32240a81d498266d77d07b7e71e"}, + {file = "safetensors-0.4.3-cp312-none-win32.whl", hash = "sha256:d0dd4a1db09db2dba0f94d15addc7e7cd3a7b0d393aa4c7518c39ae7374623c3"}, + {file = "safetensors-0.4.3-cp312-none-win_amd64.whl", hash = "sha256:d14d30c25897b2bf19b6fb5ff7e26cc40006ad53fd4a88244fdf26517d852dd7"}, + {file = "safetensors-0.4.3-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:d1456f814655b224d4bf6e7915c51ce74e389b413be791203092b7ff78c936dd"}, + {file = "safetensors-0.4.3-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:455d538aa1aae4a8b279344a08136d3f16334247907b18a5c3c7fa88ef0d3c46"}, + {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf476bca34e1340ee3294ef13e2c625833f83d096cfdf69a5342475602004f95"}, + {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:02ef3a24face643456020536591fbd3c717c5abaa2737ec428ccbbc86dffa7a4"}, + {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7de32d0d34b6623bb56ca278f90db081f85fb9c5d327e3c18fd23ac64f465768"}, + {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a0deb16a1d3ea90c244ceb42d2c6c276059616be21a19ac7101aa97da448faf"}, + {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c59d51f182c729f47e841510b70b967b0752039f79f1de23bcdd86462a9b09ee"}, + {file = "safetensors-0.4.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1f598b713cc1a4eb31d3b3203557ac308acf21c8f41104cdd74bf640c6e538e3"}, + {file = "safetensors-0.4.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5757e4688f20df083e233b47de43845d1adb7e17b6cf7da5f8444416fc53828d"}, + {file = "safetensors-0.4.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:fe746d03ed8d193674a26105e4f0fe6c726f5bb602ffc695b409eaf02f04763d"}, + {file = "safetensors-0.4.3-cp37-none-win32.whl", hash = "sha256:0d5ffc6a80f715c30af253e0e288ad1cd97a3d0086c9c87995e5093ebc075e50"}, + {file = "safetensors-0.4.3-cp37-none-win_amd64.whl", hash = "sha256:a11c374eb63a9c16c5ed146457241182f310902bd2a9c18255781bb832b6748b"}, + {file = "safetensors-0.4.3-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:b1e31be7945f66be23f4ec1682bb47faa3df34cb89fc68527de6554d3c4258a4"}, + {file = "safetensors-0.4.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:03a4447c784917c9bf01d8f2ac5080bc15c41692202cd5f406afba16629e84d6"}, + {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d244bcafeb1bc06d47cfee71727e775bca88a8efda77a13e7306aae3813fa7e4"}, + {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:53c4879b9c6bd7cd25d114ee0ef95420e2812e676314300624594940a8d6a91f"}, + {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:74707624b81f1b7f2b93f5619d4a9f00934d5948005a03f2c1845ffbfff42212"}, + {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d52c958dc210265157573f81d34adf54e255bc2b59ded6218500c9b15a750eb"}, + {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f9568f380f513a60139971169c4a358b8731509cc19112369902eddb33faa4d"}, + {file = "safetensors-0.4.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0d9cd8e1560dfc514b6d7859247dc6a86ad2f83151a62c577428d5102d872721"}, + {file = "safetensors-0.4.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:89f9f17b0dacb913ed87d57afbc8aad85ea42c1085bd5de2f20d83d13e9fc4b2"}, + {file = "safetensors-0.4.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:1139eb436fd201c133d03c81209d39ac57e129f5e74e34bb9ab60f8d9b726270"}, + {file = "safetensors-0.4.3-cp38-none-win32.whl", hash = "sha256:d9c289f140a9ae4853fc2236a2ffc9a9f2d5eae0cb673167e0f1b8c18c0961ac"}, + {file = "safetensors-0.4.3-cp38-none-win_amd64.whl", hash = "sha256:622afd28968ef3e9786562d352659a37de4481a4070f4ebac883f98c5836563e"}, + {file = "safetensors-0.4.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:8651c7299cbd8b4161a36cd6a322fa07d39cd23535b144d02f1c1972d0c62f3c"}, + {file = "safetensors-0.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e375d975159ac534c7161269de24ddcd490df2157b55c1a6eeace6cbb56903f0"}, + {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:084fc436e317f83f7071fc6a62ca1c513b2103db325cd09952914b50f51cf78f"}, + {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:41a727a7f5e6ad9f1db6951adee21bbdadc632363d79dc434876369a17de6ad6"}, + {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7dbbde64b6c534548696808a0e01276d28ea5773bc9a2dfb97a88cd3dffe3df"}, + {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bbae3b4b9d997971431c346edbfe6e41e98424a097860ee872721e176040a893"}, + {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01e4b22e3284cd866edeabe4f4d896229495da457229408d2e1e4810c5187121"}, + {file = "safetensors-0.4.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dd37306546b58d3043eb044c8103a02792cc024b51d1dd16bd3dd1f334cb3ed"}, + {file = "safetensors-0.4.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d8815b5e1dac85fc534a97fd339e12404db557878c090f90442247e87c8aeaea"}, + {file = "safetensors-0.4.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e011cc162503c19f4b1fd63dfcddf73739c7a243a17dac09b78e57a00983ab35"}, + {file = "safetensors-0.4.3-cp39-none-win32.whl", hash = "sha256:01feb3089e5932d7e662eda77c3ecc389f97c0883c4a12b5cfdc32b589a811c3"}, + {file = "safetensors-0.4.3-cp39-none-win_amd64.whl", hash = "sha256:3f9cdca09052f585e62328c1c2923c70f46814715c795be65f0b93f57ec98a02"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:1b89381517891a7bb7d1405d828b2bf5d75528299f8231e9346b8eba092227f9"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:cd6fff9e56df398abc5866b19a32124815b656613c1c5ec0f9350906fd798aac"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:840caf38d86aa7014fe37ade5d0d84e23dcfbc798b8078015831996ecbc206a3"}, + {file = 
"safetensors-0.4.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9650713b2cfa9537a2baf7dd9fee458b24a0aaaa6cafcea8bdd5fb2b8efdc34"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e4119532cd10dba04b423e0f86aecb96cfa5a602238c0aa012f70c3a40c44b50"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e066e8861eef6387b7c772344d1fe1f9a72800e04ee9a54239d460c400c72aab"}, + {file = "safetensors-0.4.3-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:90964917f5b0fa0fa07e9a051fbef100250c04d150b7026ccbf87a34a54012e0"}, + {file = "safetensors-0.4.3-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c41e1893d1206aa7054029681778d9a58b3529d4c807002c156d58426c225173"}, + {file = "safetensors-0.4.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ae7613a119a71a497d012ccc83775c308b9c1dab454806291427f84397d852fd"}, + {file = "safetensors-0.4.3-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9bac020faba7f5dc481e881b14b6425265feabb5bfc552551d21189c0eddc3"}, + {file = "safetensors-0.4.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:420a98f593ff9930f5822560d14c395ccbc57342ddff3b463bc0b3d6b1951550"}, + {file = "safetensors-0.4.3-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:f5e6883af9a68c0028f70a4c19d5a6ab6238a379be36ad300a22318316c00cb0"}, + {file = "safetensors-0.4.3-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:cdd0a3b5da66e7f377474599814dbf5cbf135ff059cc73694de129b58a5e8a2c"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:9bfb92f82574d9e58401d79c70c716985dc049b635fef6eecbb024c79b2c46ad"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:3615a96dd2dcc30eb66d82bc76cda2565f4f7bfa89fcb0e31ba3cea8a1a9ecbb"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:868ad1b6fc41209ab6bd12f63923e8baeb1a086814cb2e81a65ed3d497e0cf8f"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7ffba80aa49bd09195145a7fd233a7781173b422eeb995096f2b30591639517"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c0acbe31340ab150423347e5b9cc595867d814244ac14218932a5cf1dd38eb39"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:19bbdf95de2cf64f25cd614c5236c8b06eb2cfa47cbf64311f4b5d80224623a3"}, + {file = "safetensors-0.4.3-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b852e47eb08475c2c1bd8131207b405793bfc20d6f45aff893d3baaad449ed14"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5d07cbca5b99babb692d76d8151bec46f461f8ad8daafbfd96b2fca40cadae65"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1ab6527a20586d94291c96e00a668fa03f86189b8a9defa2cdd34a1a01acc7d5"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02318f01e332cc23ffb4f6716e05a492c5f18b1d13e343c49265149396284a44"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec4b52ce9a396260eb9731eb6aea41a7320de22ed73a1042c2230af0212758ce"}, + {file = 
"safetensors-0.4.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:018b691383026a2436a22b648873ed11444a364324e7088b99cd2503dd828400"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:309b10dbcab63269ecbf0e2ca10ce59223bb756ca5d431ce9c9eeabd446569da"}, + {file = "safetensors-0.4.3-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b277482120df46e27a58082df06a15aebda4481e30a1c21eefd0921ae7e03f65"}, + {file = "safetensors-0.4.3.tar.gz", hash = "sha256:2f85fc50c4e07a21e95c24e07460fe6f7e2859d0ce88092838352b798ce711c2"}, +] + +[package.extras] +all = ["safetensors[jax]", "safetensors[numpy]", "safetensors[paddlepaddle]", "safetensors[pinned-tf]", "safetensors[quality]", "safetensors[testing]", "safetensors[torch]"] +dev = ["safetensors[all]"] +jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "safetensors[numpy]"] +mlx = ["mlx (>=0.0.9)"] +numpy = ["numpy (>=1.21.6)"] +paddlepaddle = ["paddlepaddle (>=2.4.1)", "safetensors[numpy]"] +pinned-tf = ["safetensors[numpy]", "tensorflow (==2.11.0)"] +quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"] +tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] +testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] +torch = ["safetensors[numpy]", "torch (>=1.10)"] + +[[package]] +name = "scipy" +version = "1.13.1" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"}, + {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"}, + {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"}, + {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"}, + {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"}, + {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"}, + {file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"}, + {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"}, + {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"}, + {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"}, + {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"}, + {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"}, + {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"}, + {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"}, + {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"}, + {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"}, + {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"}, + {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"}, + {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"}, + {file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"}, + {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"}, + {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"}, + {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"}, + {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"}, + {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"}, +] + +[package.dependencies] +numpy = ">=1.22.4,<2.3" + +[package.extras] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] +test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + +[[package]] +name = "sentencepiece" +version = "0.1.99" +description = "SentencePiece python wrapper" +optional = false +python-versions = "*" +files = [ + {file = "sentencepiece-0.1.99-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0eb528e70571b7c02723e5804322469b82fe7ea418c96051d0286c0fa028db73"}, + {file = "sentencepiece-0.1.99-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:77d7fafb2c4e4659cbdf303929503f37a26eabc4ff31d3a79bf1c5a1b338caa7"}, + {file = "sentencepiece-0.1.99-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:be9cf5b9e404c245aeb3d3723c737ba7a8f5d4ba262ef233a431fa6c45f732a0"}, + {file = "sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baed1a26464998f9710d20e52607c29ffd4293e7c71c6a1f83f51ad0911ec12c"}, + {file = "sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9832f08bb372d4c8b567612f8eab9e36e268dff645f1c28f9f8e851be705f6d1"}, + {file = "sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:019e7535108e309dae2b253a75834fc3128240aa87c00eb80732078cdc182588"}, + {file = "sentencepiece-0.1.99-cp310-cp310-win32.whl", hash = "sha256:fa16a830416bb823fa2a52cbdd474d1f7f3bba527fd2304fb4b140dad31bb9bc"}, + {file = "sentencepiece-0.1.99-cp310-cp310-win_amd64.whl", hash = "sha256:14b0eccb7b641d4591c3e12ae44cab537d68352e4d3b6424944f0c447d2348d5"}, + {file = "sentencepiece-0.1.99-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6d3c56f24183a1e8bd61043ff2c58dfecdc68a5dd8955dc13bab83afd5f76b81"}, + {file = "sentencepiece-0.1.99-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed6ea1819fd612c989999e44a51bf556d0ef6abfb553080b9be3d347e18bcfb7"}, + {file = "sentencepiece-0.1.99-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a2a0260cd1fb7bd8b4d4f39dc2444a8d5fd4e0a0c4d5c899810ef1abf99b2d45"}, + {file = "sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a1abff4d1ff81c77cac3cc6fefa34fa4b8b371e5ee51cb7e8d1ebc996d05983"}, + {file = "sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:004e6a621d4bc88978eecb6ea7959264239a17b70f2cbc348033d8195c9808ec"}, + {file = "sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db361e03342c41680afae5807590bc88aa0e17cfd1a42696a160e4005fcda03b"}, + {file = "sentencepiece-0.1.99-cp311-cp311-win32.whl", hash = "sha256:2d95e19168875b70df62916eb55428a0cbcb834ac51d5a7e664eda74def9e1e0"}, + {file = "sentencepiece-0.1.99-cp311-cp311-win_amd64.whl", hash = "sha256:f90d73a6f81248a909f55d8e6ef56fec32d559e1e9af045f0b0322637cb8e5c7"}, + {file = "sentencepiece-0.1.99-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:62e24c81e74bd87a6e0d63c51beb6527e4c0add67e1a17bac18bcd2076afcfeb"}, + {file = "sentencepiece-0.1.99-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:57efcc2d51caff20d9573567d9fd3f854d9efe613ed58a439c78c9f93101384a"}, + {file = "sentencepiece-0.1.99-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6a904c46197993bd1e95b93a6e373dca2f170379d64441041e2e628ad4afb16f"}, + {file = "sentencepiece-0.1.99-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d89adf59854741c0d465f0e1525b388c0d174f611cc04af54153c5c4f36088c4"}, + {file = "sentencepiece-0.1.99-cp36-cp36m-win32.whl", hash = "sha256:47c378146928690d1bc106fdf0da768cebd03b65dd8405aa3dd88f9c81e35dba"}, + {file = "sentencepiece-0.1.99-cp36-cp36m-win_amd64.whl", hash = "sha256:9ba142e7a90dd6d823c44f9870abdad45e6c63958eb60fe44cca6828d3b69da2"}, + {file = "sentencepiece-0.1.99-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b7b1a9ae4d7c6f1f867e63370cca25cc17b6f4886729595b885ee07a58d3cec3"}, + {file = "sentencepiece-0.1.99-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0f644c9d4d35c096a538507b2163e6191512460035bf51358794a78515b74f7"}, + {file = "sentencepiece-0.1.99-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c8843d23a0f686d85e569bd6dcd0dd0e0cbc03731e63497ca6d5bacd18df8b85"}, + {file = "sentencepiece-0.1.99-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33e6f690a1caebb4867a2e367afa1918ad35be257ecdb3455d2bbd787936f155"}, + {file = "sentencepiece-0.1.99-cp37-cp37m-win32.whl", hash = "sha256:8a321866c2f85da7beac74a824b4ad6ddc2a4c9bccd9382529506d48f744a12c"}, + {file = "sentencepiece-0.1.99-cp37-cp37m-win_amd64.whl", hash = "sha256:c42f753bcfb7661c122a15b20be7f684b61fc8592c89c870adf52382ea72262d"}, + {file = 
"sentencepiece-0.1.99-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:85b476406da69c70586f0bb682fcca4c9b40e5059814f2db92303ea4585c650c"}, + {file = "sentencepiece-0.1.99-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cfbcfe13c69d3f87b7fcd5da168df7290a6d006329be71f90ba4f56bc77f8561"}, + {file = "sentencepiece-0.1.99-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:445b0ec381af1cd4eef95243e7180c63d9c384443c16c4c47a28196bd1cda937"}, + {file = "sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6890ea0f2b4703f62d0bf27932e35808b1f679bdb05c7eeb3812b935ba02001"}, + {file = "sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb71af492b0eefbf9f2501bec97bcd043b6812ab000d119eaf4bd33f9e283d03"}, + {file = "sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27b866b5bd3ddd54166bbcbf5c8d7dd2e0b397fac8537991c7f544220b1f67bc"}, + {file = "sentencepiece-0.1.99-cp38-cp38-win32.whl", hash = "sha256:b133e8a499eac49c581c3c76e9bdd08c338cc1939e441fee6f92c0ccb5f1f8be"}, + {file = "sentencepiece-0.1.99-cp38-cp38-win_amd64.whl", hash = "sha256:0eaf3591dd0690a87f44f4df129cf8d05d8a4029b5b6709b489b8e27f9a9bcff"}, + {file = "sentencepiece-0.1.99-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:38efeda9bbfb55052d482a009c6a37e52f42ebffcea9d3a98a61de7aee356a28"}, + {file = "sentencepiece-0.1.99-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c030b081dc1e1bcc9fadc314b19b740715d3d566ad73a482da20d7d46fd444c"}, + {file = "sentencepiece-0.1.99-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:84dbe53e02e4f8a2e45d2ac3e430d5c83182142658e25edd76539b7648928727"}, + {file = "sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b0f55d0a0ee1719b4b04221fe0c9f0c3461dc3dabd77a035fa2f4788eb3ef9a"}, + {file = "sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18e800f206cd235dc27dc749299e05853a4e4332e8d3dfd81bf13d0e5b9007d9"}, + {file = "sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ae1c40cda8f9d5b0423cfa98542735c0235e7597d79caf318855cdf971b2280"}, + {file = "sentencepiece-0.1.99-cp39-cp39-win32.whl", hash = "sha256:c84ce33af12ca222d14a1cdd37bd76a69401e32bc68fe61c67ef6b59402f4ab8"}, + {file = "sentencepiece-0.1.99-cp39-cp39-win_amd64.whl", hash = "sha256:350e5c74d739973f1c9643edb80f7cc904dc948578bcb1d43c6f2b173e5d18dd"}, + {file = "sentencepiece-0.1.99.tar.gz", hash = "sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f"}, +] + +[[package]] +name = "setuptools" +version = "70.0.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, + {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", 
"jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] + +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + +[[package]] +name = "sympy" +version = "1.12.1" +description = "Computer algebra system (CAS) in Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"}, + {file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"}, +] + +[package.dependencies] +mpmath = ">=1.1.0,<1.4.0" + +[[package]] +name = "tbb" +version = "2021.12.0" +description = "Intel® oneAPI Threading Building Blocks (oneTBB)" +optional = true +python-versions = "*" +files = [ + {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"}, + {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"}, + {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"}, + {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, +] + +[[package]] +name = "texttable" +version = "1.7.0" +description = "module to create simple ASCII tables" +optional = true +python-versions = "*" +files = [ + {file = "texttable-1.7.0-py2.py3-none-any.whl", hash = "sha256:72227d592c82b3d7f672731ae73e4d1f88cd8e2ef5b075a7a7f01a23a3743917"}, + {file = "texttable-1.7.0.tar.gz", hash = "sha256:2d2068fb55115807d3ac77a4ca68fa48803e84ebb0ee2340f858107a36522638"}, +] + +[[package]] +name = "tokenizers" +version = "0.19.1" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tokenizers-0.19.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:952078130b3d101e05ecfc7fc3640282d74ed26bcf691400f872563fca15ac97"}, + {file = "tokenizers-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:82c8b8063de6c0468f08e82c4e198763e7b97aabfe573fd4cf7b33930ca4df77"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f03727225feaf340ceeb7e00604825addef622d551cbd46b7b775ac834c1e1c4"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:453e4422efdfc9c6b6bf2eae00d5e323f263fff62b29a8c9cd526c5003f3f642"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:02e81bf089ebf0e7f4df34fa0207519f07e66d8491d963618252f2e0729e0b46"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:b07c538ba956843833fee1190cf769c60dc62e1cf934ed50d77d5502194d63b1"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e28cab1582e0eec38b1f38c1c1fb2e56bce5dc180acb1724574fc5f47da2a4fe"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b01afb7193d47439f091cd8f070a1ced347ad0f9144952a30a41836902fe09e"}, + {file = "tokenizers-0.19.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7fb297edec6c6841ab2e4e8f357209519188e4a59b557ea4fafcf4691d1b4c98"}, + {file = "tokenizers-0.19.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e8a3dd055e515df7054378dc9d6fa8c8c34e1f32777fb9a01fea81496b3f9d3"}, + {file = "tokenizers-0.19.1-cp310-none-win32.whl", hash = "sha256:7ff898780a155ea053f5d934925f3902be2ed1f4d916461e1a93019cc7250837"}, + {file = "tokenizers-0.19.1-cp310-none-win_amd64.whl", hash = "sha256:bea6f9947e9419c2fda21ae6c32871e3d398cba549b93f4a65a2d369662d9403"}, + {file = "tokenizers-0.19.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:5c88d1481f1882c2e53e6bb06491e474e420d9ac7bdff172610c4f9ad3898059"}, + {file = "tokenizers-0.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddf672ed719b4ed82b51499100f5417d7d9f6fb05a65e232249268f35de5ed14"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:dadc509cc8a9fe460bd274c0e16ac4184d0958117cf026e0ea8b32b438171594"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfedf31824ca4915b511b03441784ff640378191918264268e6923da48104acc"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac11016d0a04aa6487b1513a3a36e7bee7eec0e5d30057c9c0408067345c48d2"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:76951121890fea8330d3a0df9a954b3f2a37e3ec20e5b0530e9a0044ca2e11fe"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b342d2ce8fc8d00f376af068e3274e2e8649562e3bc6ae4a67784ded6b99428d"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d16ff18907f4909dca9b076b9c2d899114dd6abceeb074eca0c93e2353f943aa"}, + {file = "tokenizers-0.19.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:706a37cc5332f85f26efbe2bdc9ef8a9b372b77e4645331a405073e4b3a8c1c6"}, + {file = "tokenizers-0.19.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:16baac68651701364b0289979ecec728546133e8e8fe38f66fe48ad07996b88b"}, + {file = "tokenizers-0.19.1-cp311-none-win32.whl", hash = "sha256:9ed240c56b4403e22b9584ee37d87b8bfa14865134e3e1c3fb4b2c42fafd3256"}, + {file = "tokenizers-0.19.1-cp311-none-win_amd64.whl", hash = "sha256:ad57d59341710b94a7d9dbea13f5c1e7d76fd8d9bcd944a7a6ab0b0da6e0cc66"}, + {file = "tokenizers-0.19.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:621d670e1b1c281a1c9698ed89451395d318802ff88d1fc1accff0867a06f153"}, + {file = "tokenizers-0.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d924204a3dbe50b75630bd16f821ebda6a5f729928df30f582fb5aade90c818a"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4f3fefdc0446b1a1e6d81cd4c07088ac015665d2e812f6dbba4a06267d1a2c95"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9620b78e0b2d52ef07b0d428323fb34e8ea1219c5eac98c2596311f20f1f9266"}, + 
{file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04ce49e82d100594715ac1b2ce87d1a36e61891a91de774755f743babcd0dd52"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5c2ff13d157afe413bf7e25789879dd463e5a4abfb529a2d8f8473d8042e28f"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3174c76efd9d08f836bfccaca7cfec3f4d1c0a4cf3acbc7236ad577cc423c840"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7c9d5b6c0e7a1e979bec10ff960fae925e947aab95619a6fdb4c1d8ff3708ce3"}, + {file = "tokenizers-0.19.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a179856d1caee06577220ebcfa332af046d576fb73454b8f4d4b0ba8324423ea"}, + {file = "tokenizers-0.19.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:952b80dac1a6492170f8c2429bd11fcaa14377e097d12a1dbe0ef2fb2241e16c"}, + {file = "tokenizers-0.19.1-cp312-none-win32.whl", hash = "sha256:01d62812454c188306755c94755465505836fd616f75067abcae529c35edeb57"}, + {file = "tokenizers-0.19.1-cp312-none-win_amd64.whl", hash = "sha256:b70bfbe3a82d3e3fb2a5e9b22a39f8d1740c96c68b6ace0086b39074f08ab89a"}, + {file = "tokenizers-0.19.1-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:bb9dfe7dae85bc6119d705a76dc068c062b8b575abe3595e3c6276480e67e3f1"}, + {file = "tokenizers-0.19.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:1f0360cbea28ea99944ac089c00de7b2e3e1c58f479fb8613b6d8d511ce98267"}, + {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:71e3ec71f0e78780851fef28c2a9babe20270404c921b756d7c532d280349214"}, + {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b82931fa619dbad979c0ee8e54dd5278acc418209cc897e42fac041f5366d626"}, + {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e8ff5b90eabdcdaa19af697885f70fe0b714ce16709cf43d4952f1f85299e73a"}, + {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e742d76ad84acbdb1a8e4694f915fe59ff6edc381c97d6dfdd054954e3478ad4"}, + {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d8c5d59d7b59885eab559d5bc082b2985555a54cda04dda4c65528d90ad252ad"}, + {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b2da5c32ed869bebd990c9420df49813709e953674c0722ff471a116d97b22d"}, + {file = "tokenizers-0.19.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:638e43936cc8b2cbb9f9d8dde0fe5e7e30766a3318d2342999ae27f68fdc9bd6"}, + {file = "tokenizers-0.19.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:78e769eb3b2c79687d9cb0f89ef77223e8e279b75c0a968e637ca7043a84463f"}, + {file = "tokenizers-0.19.1-cp37-none-win32.whl", hash = "sha256:72791f9bb1ca78e3ae525d4782e85272c63faaef9940d92142aa3eb79f3407a3"}, + {file = "tokenizers-0.19.1-cp37-none-win_amd64.whl", hash = "sha256:f3bbb7a0c5fcb692950b041ae11067ac54826204318922da754f908d95619fbc"}, + {file = "tokenizers-0.19.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:07f9295349bbbcedae8cefdbcfa7f686aa420be8aca5d4f7d1ae6016c128c0c5"}, + {file = "tokenizers-0.19.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:10a707cc6c4b6b183ec5dbfc5c34f3064e18cf62b4a938cb41699e33a99e03c1"}, + {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = 
"sha256:6309271f57b397aa0aff0cbbe632ca9d70430839ca3178bf0f06f825924eca22"}, + {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ad23d37d68cf00d54af184586d79b84075ada495e7c5c0f601f051b162112dc"}, + {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:427c4f0f3df9109314d4f75b8d1f65d9477033e67ffaec4bca53293d3aca286d"}, + {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e83a31c9cf181a0a3ef0abad2b5f6b43399faf5da7e696196ddd110d332519ee"}, + {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c27b99889bd58b7e301468c0838c5ed75e60c66df0d4db80c08f43462f82e0d3"}, + {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bac0b0eb952412b0b196ca7a40e7dce4ed6f6926489313414010f2e6b9ec2adf"}, + {file = "tokenizers-0.19.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8a6298bde623725ca31c9035a04bf2ef63208d266acd2bed8c2cb7d2b7d53ce6"}, + {file = "tokenizers-0.19.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:08a44864e42fa6d7d76d7be4bec62c9982f6f6248b4aa42f7302aa01e0abfd26"}, + {file = "tokenizers-0.19.1-cp38-none-win32.whl", hash = "sha256:1de5bc8652252d9357a666e609cb1453d4f8e160eb1fb2830ee369dd658e8975"}, + {file = "tokenizers-0.19.1-cp38-none-win_amd64.whl", hash = "sha256:0bcce02bf1ad9882345b34d5bd25ed4949a480cf0e656bbd468f4d8986f7a3f1"}, + {file = "tokenizers-0.19.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0b9394bd204842a2a1fd37fe29935353742be4a3460b6ccbaefa93f58a8df43d"}, + {file = "tokenizers-0.19.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4692ab92f91b87769d950ca14dbb61f8a9ef36a62f94bad6c82cc84a51f76f6a"}, + {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6258c2ef6f06259f70a682491c78561d492e885adeaf9f64f5389f78aa49a051"}, + {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c85cf76561fbd01e0d9ea2d1cbe711a65400092bc52b5242b16cfd22e51f0c58"}, + {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:670b802d4d82bbbb832ddb0d41df7015b3e549714c0e77f9bed3e74d42400fbe"}, + {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:85aa3ab4b03d5e99fdd31660872249df5e855334b6c333e0bc13032ff4469c4a"}, + {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cbf001afbbed111a79ca47d75941e9e5361297a87d186cbfc11ed45e30b5daba"}, + {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c89aa46c269e4e70c4d4f9d6bc644fcc39bb409cb2a81227923404dd6f5227"}, + {file = "tokenizers-0.19.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:39c1ec76ea1027438fafe16ecb0fb84795e62e9d643444c1090179e63808c69d"}, + {file = "tokenizers-0.19.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c2a0d47a89b48d7daa241e004e71fb5a50533718897a4cd6235cb846d511a478"}, + {file = "tokenizers-0.19.1-cp39-none-win32.whl", hash = "sha256:61b7fe8886f2e104d4caf9218b157b106207e0f2a4905c9c7ac98890688aabeb"}, + {file = "tokenizers-0.19.1-cp39-none-win_amd64.whl", hash = "sha256:f97660f6c43efd3e0bfd3f2e3e5615bf215680bad6ee3d469df6454b8c6e8256"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3b11853f17b54c2fe47742c56d8a33bf49ce31caf531e87ac0d7d13d327c9334"}, + 
{file = "tokenizers-0.19.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d26194ef6c13302f446d39972aaa36a1dda6450bc8949f5eb4c27f51191375bd"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e8d1ed93beda54bbd6131a2cb363a576eac746d5c26ba5b7556bc6f964425594"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca407133536f19bdec44b3da117ef0d12e43f6d4b56ac4c765f37eca501c7bda"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce05fde79d2bc2e46ac08aacbc142bead21614d937aac950be88dc79f9db9022"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:35583cd46d16f07c054efd18b5d46af4a2f070a2dd0a47914e66f3ff5efb2b1e"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:43350270bfc16b06ad3f6f07eab21f089adb835544417afda0f83256a8bf8b75"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b4399b59d1af5645bcee2072a463318114c39b8547437a7c2d6a186a1b5a0e2d"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6852c5b2a853b8b0ddc5993cd4f33bfffdca4fcc5d52f89dd4b8eada99379285"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bcd266ae85c3d39df2f7e7d0e07f6c41a55e9a3123bb11f854412952deacd828"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ecb2651956eea2aa0a2d099434134b1b68f1c31f9a5084d6d53f08ed43d45ff2"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:b279ab506ec4445166ac476fb4d3cc383accde1ea152998509a94d82547c8e2a"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:89183e55fb86e61d848ff83753f64cded119f5d6e1f553d14ffee3700d0a4a49"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2edbc75744235eea94d595a8b70fe279dd42f3296f76d5a86dde1d46e35f574"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:0e64bfde9a723274e9a71630c3e9494ed7b4c0f76a1faacf7fe294cd26f7ae7c"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0b5ca92bfa717759c052e345770792d02d1f43b06f9e790ca0a1db62838816f3"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f8a20266e695ec9d7a946a019c1d5ca4eddb6613d4f466888eee04f16eedb85"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63c38f45d8f2a2ec0f3a20073cccb335b9f99f73b3c69483cd52ebc75369d8a1"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dd26e3afe8a7b61422df3176e06664503d3f5973b94f45d5c45987e1cb711876"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:eddd5783a4a6309ce23432353cdb36220e25cbb779bfa9122320666508b44b88"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:56ae39d4036b753994476a1b935584071093b55c7a72e3b8288e68c313ca26e7"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f9939ca7e58c2758c01b40324a59c034ce0cebad18e0d4563a9b1beab3018243"}, + {file = 
"tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6c330c0eb815d212893c67a032e9dc1b38a803eccb32f3e8172c19cc69fbb439"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec11802450a2487cdf0e634b750a04cbdc1c4d066b97d94ce7dd2cb51ebb325b"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2b718f316b596f36e1dae097a7d5b91fc5b85e90bf08b01ff139bd8953b25af"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ed69af290c2b65169f0ba9034d1dc39a5db9459b32f1dd8b5f3f32a3fcf06eab"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f8a9c828277133af13f3859d1b6bf1c3cb6e9e1637df0e45312e6b7c2e622b1f"}, + {file = "tokenizers-0.19.1.tar.gz", hash = "sha256:ee59e6680ed0fdbe6b724cf38bd70400a0c1dd623b07ac729087270caeac88e3"}, +] + +[package.dependencies] +huggingface-hub = ">=0.16.4,<1.0" + +[package.extras] +dev = ["tokenizers[testing]"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] + +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + +[[package]] +name = "torch" +version = "2.3.0" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"}, + {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"}, + {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"}, + {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"}, + {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"}, + {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"}, + {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"}, + {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"}, + {file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"}, + {file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"}, + {file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"}, + {file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"}, + {file = "torch-2.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = 
"sha256:20572f426965dd8a04e92a473d7e445fa579e09943cc0354f3e6fef6130ce061"}, + {file = "torch-2.3.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e65ba85ae292909cde0dde6369826d51165a3fc8823dc1854cd9432d7f79b932"}, + {file = "torch-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:5515503a193781fd1b3f5c474e89c9dfa2faaa782b2795cc4a7ab7e67de923f6"}, + {file = "torch-2.3.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:6ae9f64b09516baa4ef890af0672dc981c20b1f0d829ce115d4420a247e88fba"}, + {file = "torch-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9"}, + {file = "torch-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80"}, + {file = "torch-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea"}, + {file = "torch-2.3.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380"}, +] + +[package.dependencies] +filelock = "*" +fsspec = "*" +jinja2 = "*" +mkl = {version = ">=2021.1.1,<=2021.4.0", markers = "platform_system == \"Windows\""} +networkx = "*" +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +sympy = "*" +triton = {version = "2.3.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} +typing-extensions = ">=4.8.0" + +[package.extras] +opt-einsum = ["opt-einsum (>=3.3)"] +optree = ["optree (>=0.9.1)"] + +[[package]] +name = "tqdm" +version = "4.66.4" +description = "Fast, Extensible Progress Meter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"}, + {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + +[[package]] +name = 
"transformers" +version = "4.41.2" +description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "transformers-4.41.2-py3-none-any.whl", hash = "sha256:05555d20e43f808de1ef211ab64803cdb513170cef70d29a888b589caebefc67"}, + {file = "transformers-4.41.2.tar.gz", hash = "sha256:80a4db216533d573e9cc7388646c31ed9480918feb7c55eb211249cb23567f87"}, +] + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.23.0,<1.0" +numpy = ">=1.17" +packaging = ">=20.0" +pyyaml = ">=5.1" +regex = "!=2019.12.17" +requests = "*" +safetensors = ">=0.4.1" +tokenizers = ">=0.19,<0.20" +tqdm = ">=4.27" + +[package.extras] +accelerate = ["accelerate (>=0.21.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] +audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +codecarbon = ["codecarbon (==1.2.0)"] +deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", 
"datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] +flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +ftfy = ["ftfy"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +modelcreation = ["cookiecutter (==1.7.3)"] +natten = ["natten (>=0.14.6,<0.15.0)"] +onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] +onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +optuna = ["optuna"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] +retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +sagemaker = ["sagemaker (>=2.31.0)"] +sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic", "starlette", "uvicorn"] +sigopt = ["sigopt"] +sklearn = ["scikit-learn"] +speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", 
"tf2onnx"] +tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +timm = ["timm"] +tokenizers = ["tokenizers (>=0.19,<0.20)"] +torch = ["accelerate (>=0.21.0)", "torch"] +torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.23.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] +video = ["av (==9.2.0)", "decord (==0.6.0)"] +vision = ["Pillow (>=10.0.1,<=15.0)"] + +[[package]] +name = "triton" +version = "2.3.0" +description = "A language and compiler for custom Deep Learning operations" +optional = true +python-versions = "*" +files = [ + {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"}, + {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"}, + {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"}, + {file = "triton-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381ec6b3dac06922d3e4099cfc943ef032893b25415de295e82b1a82b0359d2c"}, + {file = "triton-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:038e06a09c06a164fef9c48de3af1e13a63dc1ba3c792871e61a8e79720ea440"}, + {file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"}, +] + +[package.dependencies] +filelock = "*" + +[package.extras] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] +tutorials = ["matplotlib", "pandas", "tabulate", "torch"] + +[[package]] +name = "typer" +version = "0.6.1" +description = "Typer, build great CLIs. Easy to code. Based on Python type hints." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "typer-0.6.1-py3-none-any.whl", hash = "sha256:54b19e5df18654070a82f8c2aa1da456a4ac16a2a83e6dcd9f170e291c56338e"}, + {file = "typer-0.6.1.tar.gz", hash = "sha256:2d5720a5e63f73eaf31edaa15f6ab87f35f0690f8ca233017d7d23d743a91d73"}, +] + +[package.dependencies] +click = ">=7.1.1,<9.0.0" + +[package.extras] +all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] +doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<5.4.0)", "pytest-cov (>=2.10.0,<3.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<2.0.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] + +[[package]] +name = "typing-extensions" +version = "4.12.1" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.12.1-py3-none-any.whl", hash = "sha256:6024b58b69089e5a89c347397254e35f1bf02a907728ec7fee9bf0fe837d203a"}, + {file = "typing_extensions-4.12.1.tar.gz", hash = "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1"}, +] + +[[package]] +name = "tzdata" +version = "2024.1" +description = "Provider of IANA time zone data" +optional = true +python-versions = ">=2" +files = [ + {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, + {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, +] + +[[package]] +name = "urllib3" +version = "2.2.1" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = ">=3.8" +files = [ + {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, + {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, +] + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +h2 = ["h2 (>=4,<5)"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + +[[package]] +name = "win32-setctime" +version = "1.1.0" +description = "A small Python utility to set file creation time on Windows" +optional = false +python-versions = ">=3.5" +files = [ + {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, + {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, +] + +[package.extras] +dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] + +[[package]] +name = "wrapt" +version = "1.16.0" +description = "Module for decorators, wrappers and monkey patching." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "wrapt-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4"}, + {file = "wrapt-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020"}, + {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440"}, + {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487"}, + {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf"}, + {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72"}, + {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0"}, + {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136"}, + {file = "wrapt-1.16.0-cp310-cp310-win32.whl", hash = "sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d"}, + {file = "wrapt-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2"}, + {file = "wrapt-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09"}, + {file = "wrapt-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d"}, + {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389"}, + {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060"}, + {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1"}, + {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3"}, + {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956"}, + {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d"}, + {file = "wrapt-1.16.0-cp311-cp311-win32.whl", hash = "sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362"}, + {file = "wrapt-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89"}, + {file = "wrapt-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b"}, + {file = "wrapt-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36"}, + {file = 
"wrapt-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73"}, + {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809"}, + {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b"}, + {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81"}, + {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9"}, + {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c"}, + {file = "wrapt-1.16.0-cp312-cp312-win32.whl", hash = "sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc"}, + {file = "wrapt-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8"}, + {file = "wrapt-1.16.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8"}, + {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39"}, + {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c"}, + {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40"}, + {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc"}, + {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e"}, + {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465"}, + {file = "wrapt-1.16.0-cp36-cp36m-win32.whl", hash = "sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e"}, + {file = "wrapt-1.16.0-cp36-cp36m-win_amd64.whl", hash = "sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966"}, + {file = "wrapt-1.16.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593"}, + {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292"}, + {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5"}, + {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf"}, + {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228"}, + {file = 
"wrapt-1.16.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f"}, + {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c"}, + {file = "wrapt-1.16.0-cp37-cp37m-win32.whl", hash = "sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c"}, + {file = "wrapt-1.16.0-cp37-cp37m-win_amd64.whl", hash = "sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00"}, + {file = "wrapt-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0"}, + {file = "wrapt-1.16.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202"}, + {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0"}, + {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e"}, + {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f"}, + {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267"}, + {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca"}, + {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6"}, + {file = "wrapt-1.16.0-cp38-cp38-win32.whl", hash = "sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b"}, + {file = "wrapt-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41"}, + {file = "wrapt-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2"}, + {file = "wrapt-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb"}, + {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8"}, + {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c"}, + {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a"}, + {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664"}, + {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f"}, + {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537"}, + {file = "wrapt-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3"}, + {file = "wrapt-1.16.0-cp39-cp39-win_amd64.whl", 
hash = "sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35"}, + {file = "wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1"}, + {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"}, +] + +[[package]] +name = "xxhash" +version = "3.4.1" +description = "Python binding for xxHash" +optional = true +python-versions = ">=3.7" +files = [ + {file = "xxhash-3.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:91dbfa55346ad3e18e738742236554531a621042e419b70ad8f3c1d9c7a16e7f"}, + {file = "xxhash-3.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:665a65c2a48a72068fcc4d21721510df5f51f1142541c890491afc80451636d2"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb11628470a6004dc71a09fe90c2f459ff03d611376c1debeec2d648f44cb693"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5bef2a7dc7b4f4beb45a1edbba9b9194c60a43a89598a87f1a0226d183764189"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c0f7b2d547d72c7eda7aa817acf8791f0146b12b9eba1d4432c531fb0352228"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00f2fdef6b41c9db3d2fc0e7f94cb3db86693e5c45d6de09625caad9a469635b"}, + {file = "xxhash-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:23cfd9ca09acaf07a43e5a695143d9a21bf00f5b49b15c07d5388cadf1f9ce11"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6a9ff50a3cf88355ca4731682c168049af1ca222d1d2925ef7119c1a78e95b3b"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:f1d7c69a1e9ca5faa75546fdd267f214f63f52f12692f9b3a2f6467c9e67d5e7"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:672b273040d5d5a6864a36287f3514efcd1d4b1b6a7480f294c4b1d1ee1b8de0"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4178f78d70e88f1c4a89ff1ffe9f43147185930bb962ee3979dba15f2b1cc799"}, + {file = "xxhash-3.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9804b9eb254d4b8cc83ab5a2002128f7d631dd427aa873c8727dba7f1f0d1c2b"}, + {file = "xxhash-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c09c49473212d9c87261d22c74370457cfff5db2ddfc7fd1e35c80c31a8c14ce"}, + {file = "xxhash-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:ebbb1616435b4a194ce3466d7247df23499475c7ed4eb2681a1fa42ff766aff6"}, + {file = "xxhash-3.4.1-cp310-cp310-win_arm64.whl", hash = "sha256:25dc66be3db54f8a2d136f695b00cfe88018e59ccff0f3b8f545869f376a8a46"}, + {file = "xxhash-3.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58c49083801885273e262c0f5bbeac23e520564b8357fbb18fb94ff09d3d3ea5"}, + {file = "xxhash-3.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b526015a973bfbe81e804a586b703f163861da36d186627e27524f5427b0d520"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36ad4457644c91a966f6fe137d7467636bdc51a6ce10a1d04f365c70d6a16d7e"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:248d3e83d119770f96003271fe41e049dd4ae52da2feb8f832b7a20e791d2920"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:2070b6d5bbef5ee031666cf21d4953c16e92c2f8a24a94b5c240f8995ba3b1d0"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2746035f518f0410915e247877f7df43ef3372bf36cfa52cc4bc33e85242641"}, + {file = "xxhash-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a8ba6181514681c2591840d5632fcf7356ab287d4aff1c8dea20f3c78097088"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0aac5010869240e95f740de43cd6a05eae180c59edd182ad93bf12ee289484fa"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4cb11d8debab1626181633d184b2372aaa09825bde709bf927704ed72765bed1"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b29728cff2c12f3d9f1d940528ee83918d803c0567866e062683f300d1d2eff3"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:a15cbf3a9c40672523bdb6ea97ff74b443406ba0ab9bca10ceccd9546414bd84"}, + {file = "xxhash-3.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6e66df260fed01ed8ea790c2913271641c58481e807790d9fca8bfd5a3c13844"}, + {file = "xxhash-3.4.1-cp311-cp311-win32.whl", hash = "sha256:e867f68a8f381ea12858e6d67378c05359d3a53a888913b5f7d35fbf68939d5f"}, + {file = "xxhash-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:200a5a3ad9c7c0c02ed1484a1d838b63edcf92ff538770ea07456a3732c577f4"}, + {file = "xxhash-3.4.1-cp311-cp311-win_arm64.whl", hash = "sha256:1d03f1c0d16d24ea032e99f61c552cb2b77d502e545187338bea461fde253583"}, + {file = "xxhash-3.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c4bbba9b182697a52bc0c9f8ec0ba1acb914b4937cd4a877ad78a3b3eeabefb3"}, + {file = "xxhash-3.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9fd28a9da300e64e434cfc96567a8387d9a96e824a9be1452a1e7248b7763b78"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6066d88c9329ab230e18998daec53d819daeee99d003955c8db6fc4971b45ca3"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:93805bc3233ad89abf51772f2ed3355097a5dc74e6080de19706fc447da99cd3"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64da57d5ed586ebb2ecdde1e997fa37c27fe32fe61a656b77fabbc58e6fbff6e"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a97322e9a7440bf3c9805cbaac090358b43f650516486746f7fa482672593df"}, + {file = "xxhash-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bbe750d512982ee7d831838a5dee9e9848f3fb440e4734cca3f298228cc957a6"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fd79d4087727daf4d5b8afe594b37d611ab95dc8e29fe1a7517320794837eb7d"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:743612da4071ff9aa4d055f3f111ae5247342931dedb955268954ef7201a71ff"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:b41edaf05734092f24f48c0958b3c6cbaaa5b7e024880692078c6b1f8247e2fc"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:a90356ead70d715fe64c30cd0969072de1860e56b78adf7c69d954b43e29d9fa"}, + {file = "xxhash-3.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ac56eebb364e44c85e1d9e9cc5f6031d78a34f0092fea7fc80478139369a8b4a"}, + {file = "xxhash-3.4.1-cp312-cp312-win32.whl", hash = 
"sha256:911035345932a153c427107397c1518f8ce456f93c618dd1c5b54ebb22e73747"}, + {file = "xxhash-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:f31ce76489f8601cc7b8713201ce94b4bd7b7ce90ba3353dccce7e9e1fee71fa"}, + {file = "xxhash-3.4.1-cp312-cp312-win_arm64.whl", hash = "sha256:b5beb1c6a72fdc7584102f42c4d9df232ee018ddf806e8c90906547dfb43b2da"}, + {file = "xxhash-3.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6d42b24d1496deb05dee5a24ed510b16de1d6c866c626c2beb11aebf3be278b9"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b685fab18876b14a8f94813fa2ca80cfb5ab6a85d31d5539b7cd749ce9e3624"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:419ffe34c17ae2df019a4685e8d3934d46b2e0bbe46221ab40b7e04ed9f11137"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0e041ce5714f95251a88670c114b748bca3bf80cc72400e9f23e6d0d59cf2681"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc860d887c5cb2f524899fb8338e1bb3d5789f75fac179101920d9afddef284b"}, + {file = "xxhash-3.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:312eba88ffe0a05e332e3a6f9788b73883752be63f8588a6dc1261a3eaaaf2b2"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:e01226b6b6a1ffe4e6bd6d08cfcb3ca708b16f02eb06dd44f3c6e53285f03e4f"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:9f3025a0d5d8cf406a9313cd0d5789c77433ba2004b1c75439b67678e5136537"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:6d3472fd4afef2a567d5f14411d94060099901cd8ce9788b22b8c6f13c606a93"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:43984c0a92f06cac434ad181f329a1445017c33807b7ae4f033878d860a4b0f2"}, + {file = "xxhash-3.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a55e0506fdb09640a82ec4f44171273eeabf6f371a4ec605633adb2837b5d9d5"}, + {file = "xxhash-3.4.1-cp37-cp37m-win32.whl", hash = "sha256:faec30437919555b039a8bdbaba49c013043e8f76c999670aef146d33e05b3a0"}, + {file = "xxhash-3.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:c9e1b646af61f1fc7083bb7b40536be944f1ac67ef5e360bca2d73430186971a"}, + {file = "xxhash-3.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:961d948b7b1c1b6c08484bbce3d489cdf153e4122c3dfb07c2039621243d8795"}, + {file = "xxhash-3.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:719a378930504ab159f7b8e20fa2aa1896cde050011af838af7e7e3518dd82de"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74fb5cb9406ccd7c4dd917f16630d2e5e8cbbb02fc2fca4e559b2a47a64f4940"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5dab508ac39e0ab988039bc7f962c6ad021acd81fd29145962b068df4148c476"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8c59f3e46e7daf4c589e8e853d700ef6607afa037bfad32c390175da28127e8c"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cc07256eff0795e0f642df74ad096f8c5d23fe66bc138b83970b50fc7f7f6c5"}, + {file = "xxhash-3.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9f749999ed80f3955a4af0eb18bb43993f04939350b07b8dd2f44edc98ffee9"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", 
hash = "sha256:7688d7c02149a90a3d46d55b341ab7ad1b4a3f767be2357e211b4e893efbaaf6"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a8b4977963926f60b0d4f830941c864bed16aa151206c01ad5c531636da5708e"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:8106d88da330f6535a58a8195aa463ef5281a9aa23b04af1848ff715c4398fb4"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4c76a77dbd169450b61c06fd2d5d436189fc8ab7c1571d39265d4822da16df22"}, + {file = "xxhash-3.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:11f11357c86d83e53719c592021fd524efa9cf024dc7cb1dfb57bbbd0d8713f2"}, + {file = "xxhash-3.4.1-cp38-cp38-win32.whl", hash = "sha256:0c786a6cd74e8765c6809892a0d45886e7c3dc54de4985b4a5eb8b630f3b8e3b"}, + {file = "xxhash-3.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:aabf37fb8fa27430d50507deeab2ee7b1bcce89910dd10657c38e71fee835594"}, + {file = "xxhash-3.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6127813abc1477f3a83529b6bbcfeddc23162cece76fa69aee8f6a8a97720562"}, + {file = "xxhash-3.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ef2e194262f5db16075caea7b3f7f49392242c688412f386d3c7b07c7733a70a"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71be94265b6c6590f0018bbf73759d21a41c6bda20409782d8117e76cd0dfa8b"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:10e0a619cdd1c0980e25eb04e30fe96cf8f4324758fa497080af9c21a6de573f"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa122124d2e3bd36581dd78c0efa5f429f5220313479fb1072858188bc2d5ff1"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17032f5a4fea0a074717fe33477cb5ee723a5f428de7563e75af64bfc1b1e10"}, + {file = "xxhash-3.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca7783b20e3e4f3f52f093538895863f21d18598f9a48211ad757680c3bd006f"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d77d09a1113899fad5f354a1eb4f0a9afcf58cefff51082c8ad643ff890e30cf"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:21287bcdd299fdc3328cc0fbbdeaa46838a1c05391264e51ddb38a3f5b09611f"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:dfd7a6cc483e20b4ad90224aeb589e64ec0f31e5610ab9957ff4314270b2bf31"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:543c7fcbc02bbb4840ea9915134e14dc3dc15cbd5a30873a7a5bf66039db97ec"}, + {file = "xxhash-3.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fe0a98d990e433013f41827b62be9ab43e3cf18e08b1483fcc343bda0d691182"}, + {file = "xxhash-3.4.1-cp39-cp39-win32.whl", hash = "sha256:b9097af00ebf429cc7c0e7d2fdf28384e4e2e91008130ccda8d5ae653db71e54"}, + {file = "xxhash-3.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:d699b921af0dcde50ab18be76c0d832f803034d80470703700cb7df0fbec2832"}, + {file = "xxhash-3.4.1-cp39-cp39-win_arm64.whl", hash = "sha256:2be491723405e15cc099ade1280133ccfbf6322d2ef568494fb7d07d280e7eee"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:431625fad7ab5649368c4849d2b49a83dc711b1f20e1f7f04955aab86cd307bc"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc6dbd5fc3c9886a9e041848508b7fb65fd82f94cc793253990f81617b61fe49"}, + {file = 
"xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3ff8dbd0ec97aec842476cb8ccc3e17dd288cd6ce3c8ef38bff83d6eb927817"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef73a53fe90558a4096e3256752268a8bdc0322f4692ed928b6cd7ce06ad4fe3"}, + {file = "xxhash-3.4.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:450401f42bbd274b519d3d8dcf3c57166913381a3d2664d6609004685039f9d3"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a162840cf4de8a7cd8720ff3b4417fbc10001eefdd2d21541a8226bb5556e3bb"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b736a2a2728ba45017cb67785e03125a79d246462dfa892d023b827007412c52"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d0ae4c2e7698adef58710d6e7a32ff518b66b98854b1c68e70eee504ad061d8"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6322c4291c3ff174dcd104fae41500e75dad12be6f3085d119c2c8a80956c51"}, + {file = "xxhash-3.4.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:dd59ed668801c3fae282f8f4edadf6dc7784db6d18139b584b6d9677ddde1b6b"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:92693c487e39523a80474b0394645b393f0ae781d8db3474ccdcead0559ccf45"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4603a0f642a1e8d7f3ba5c4c25509aca6a9c1cc16f85091004a7028607ead663"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fa45e8cbfbadb40a920fe9ca40c34b393e0b067082d94006f7f64e70c7490a6"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:595b252943b3552de491ff51e5bb79660f84f033977f88f6ca1605846637b7c6"}, + {file = "xxhash-3.4.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:562d8b8f783c6af969806aaacf95b6c7b776929ae26c0cd941d54644ea7ef51e"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:41ddeae47cf2828335d8d991f2d2b03b0bdc89289dc64349d712ff8ce59d0647"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c44d584afdf3c4dbb3277e32321d1a7b01d6071c1992524b6543025fb8f4206f"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd7bddb3a5b86213cc3f2c61500c16945a1b80ecd572f3078ddbbe68f9dabdfb"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9ecb6c987b62437c2f99c01e97caf8d25660bf541fe79a481d05732e5236719c"}, + {file = "xxhash-3.4.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:696b4e18b7023527d5c50ed0626ac0520edac45a50ec7cf3fc265cd08b1f4c03"}, + {file = "xxhash-3.4.1.tar.gz", hash = "sha256:0379d6cf1ff987cd421609a264ce025e74f346e3e145dd106c0cc2e3ec3f99a9"}, +] + +[[package]] +name = "yarl" +version = "1.9.4" +description = "Yet another URL library" +optional = true +python-versions = ">=3.7" +files = [ + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = 
"sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"}, + {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"}, + {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"}, + {file = 
"yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"}, + {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"}, + {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"}, + {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"}, + {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"}, + {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"}, + {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"}, + {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = 
"sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"}, + {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"}, + {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"}, + {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"}, + {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"}, + {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"}, + {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"}, +] + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" + +[extras] +accelerate = ["accelerate"] +bnb = ["bitsandbytes"] +outlines = ["outlines"] +peft = ["peft"] +quantize = ["accelerate", "datasets", "texttable"] +torch = ["torch"] + +[metadata] +lock-version = "2.0" +python-versions = ">=3.9,<3.13" +content-hash = "06e67944a2b1cf9884a31e771d0e9d89877e9b3c91894982cb67d104cb834758" diff --git a/server/punica_kernels/__init__.py b/server/punica_kernels/__init__.py index 4d1fd627..5604448c 100644 --- a/server/punica_kernels/__init__.py +++ b/server/punica_kernels/__init__.py @@ -2,6 +2,7 @@ import punica.ops._kernels as _kernels + def bgmv( y: torch.Tensor, x: torch.Tensor, diff --git 
a/server/punica_kernels/include/cutlass/cute/algorithm/axpby.hpp b/server/punica_kernels/include/cutlass/cute/algorithm/axpby.hpp new file mode 100644 index 00000000..df9605b7 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/algorithm/axpby.hpp @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +axpby(Alpha const& alpha, + Tensor const& x, + Beta const& beta, + Tensor && y, + PrdTensor const& p = {}) +{ + return axpby(alpha, x, beta, y, p); +} + +// +// AXPBY +// +template +CUTE_HOST_DEVICE +void +axpby(Alpha const& alpha, + Tensor const& x, + Beta const& beta, + Tensor & y, + PrdTensor const& p = {}) +{ + auto isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + + CUTE_GCC_UNREACHABLE; + } (); + + CUTE_UNROLL + for (int i = 0; i < size(x); ++i) { + if (p(i)) { + y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i)); + } + } +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/algorithm/clear.hpp b/server/punica_kernels/include/cutlass/cute/algorithm/clear.hpp new file mode 100644 index 00000000..f738b35b --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/algorithm/clear.hpp @@ -0,0 +1,66 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +clear(Tensor&& tensor) +{ + return clear(tensor); +} + +// +// Set elements to zero +// +template +CUTE_HOST_DEVICE +void +clear(Tensor& tensor) +{ + using T = typename Tensor::value_type; + + fill(tensor, T{}); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/algorithm/cooperative_copy.hpp b/server/punica_kernels/include/cutlass/cute/algorithm/cooperative_copy.hpp new file mode 100644 index 00000000..c0337aba --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/algorithm/cooperative_copy.hpp @@ -0,0 +1,196 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include + +#include +#include + +namespace cute +{ + +// cooperative_copy(thr_idx, src, dst) +// Use NumThreads to copy src to dst with element vectorization up to MaxVecBits. +// @pre 0 <= @a tid < NumThreads +// @pre Tensors @a src and @a dst are aligned up to MaxVecBits. +// +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor & dst) +{ + // Assumes the shapes are static, can generalize + CUTE_STATIC_ASSERT_V(size(src) == size(dst)); + // Assumes the types are the same, can generalize + static_assert(sizeof_bits_v == sizeof_bits_v); + static_assert(MaxVecBits == sizeof_bits_v || + MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128, + "Expected MaxVecBits to be value size or 8 or 16 or 32 or 64 or 128 for alignment and performance."); + // Check that the tensors are likely shared across threads: either gmem or smem + static_assert((is_gmem::value || is_smem::value), + "cooperative_copy expects shared gmem or smem source tensor."); + static_assert((is_gmem::value || is_smem::value), + "cooperative_copy expects shared gmem or smem destination tensor."); + + // Precondition on tid in DEBUG + assert(tid < NumThreads); + // Precondition on pointer alignment in DEBUG + assert(is_byte_aligned(raw_pointer_cast(src.data()))); + assert(is_byte_aligned(raw_pointer_cast(dst.data()))); + // + // Determine val+thr vectorization based on src/dst size and number of threads + // NOTE: This heuristic promotes parallelization over vectorization + // + + constexpr int elem_bits = sizeof_bits_v; + + // The number of elements that can be vectorized in values + constexpr int common_elem = decltype(max_common_vector(src, dst))::value; + constexpr int common_bits = common_elem * elem_bits; + constexpr int total_elem = decltype(size(src))::value; + constexpr int total_bits = total_elem * elem_bits; + static_assert(total_bits % NumThreads == 0); + constexpr int total_bits_per_thr = total_bits / NumThreads; + // If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits + constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr); + + // Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits + constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast(MaxVecBits)); + // Convert back to number of elements, safe_div + static_assert((vec_bits % elem_bits) == 0); + constexpr int vec_elem = vec_bits / elem_bits; + + // Use only part of threads if there's not enough work for all threads + 
constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0) + ? NumThreads + : (total_elem / vec_elem); + + // The common layout of the two tensors that can be vectorized over threads + // vidx -> coord + auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()), + get_nonswizzle_portion(dst.layout())); + + // Scale up the common_layout to cover the entire tensors + // vidx -> coord + auto full_perm = tile_to_shape(make_layout(common_layout), size(src)); + + // Create the Tiler + // ((vid,tid),iter) + auto layout_vt = logical_divide(full_perm, Layout, Int>>{}); + + // Apply and slice + Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_); + Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_); + + // Should account for vec_bits < 8 and/or vec_elem <= 1 + // And also account for subbyte types, which could cause race conditions + // Want to ENFORCE sufficient vectorization in those cases + static_assert((vec_bits >= 8), "No support for subbyte copying"); + using VecType = uint_bit_t; + +#if 0 + if (thread0()) { + print(" "); print("NumThreads: "); print(NumThreads); print("\n"); + print(" "); print("src: "); print(src); print("\n"); + print(" "); print("dst: "); print(dst); print("\n"); + print(" "); print("common_layout: "); print(common_layout); print("\n"); + print(" "); print("full_perm: "); print(full_perm); print("\n"); + print(" "); print("Used vector: "); print(vec_elem); print("\n"); + print(" "); print("Used threads: "); print(vec_thrs); print("\n"); + print(" "); print("layout_vt: "); print(layout_vt); print("\n"); + print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n"); + print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n"); + print(" "); print("src_v: "); print(src_v); print("\n"); + print(" "); print("dst_v: "); print(dst_v); print("\n"); + print(" "); print("recast(src_v): "); print(recast(src_v)); print("\n"); + print(" "); print("recast(dst_v): "); print(recast(dst_v)); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + // If we're using all threads (static) or the tid is in in-range (dynamic) + if (vec_thrs >= NumThreads or tid < vec_thrs) { + return copy_if(TrivialPredTensor{}, recast(src_v), recast(dst_v)); + } +} + +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor & dst) +{ + constexpr uint32_t MaxVecBits = sizeof_bits_v; + return cooperative_copy(tid, src, dst); +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor && dst) +{ + return cooperative_copy(tid, src, dst); +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor && dst) +{ + return cooperative_copy(tid, src, dst); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/algorithm/cooperative_gemm.hpp b/server/punica_kernels/include/cutlass/cute/algorithm/cooperative_gemm.hpp new file mode 100644 index 00000000..32cec54b --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/algorithm/cooperative_gemm.hpp @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +#include +#include + +#include + +namespace cute +{ + +// +// Collective Shared-Memory GEMMs +// + +template ::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_smem::value)> +CUTE_HOST_DEVICE +void +cooperative_gemm(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC, + ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */, + BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */) +{ + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using TypeA = typename TA::value_type; + using TypeB = typename TB::value_type; + using TypeC = typename TC::value_type; + + static_assert(is_same_v>, TypeA>, + "ALoadTransformOp functor must accept and return value of type TA::value_type"); + static_assert(is_same_v>, TypeB>, + "BLoadTransformOp functor must accept and return value of type TB::value_type"); + + // Original, static size of the problem + auto M = size<0>(sC); + auto N = size<1>(sC); + auto K = size<1>(sA); + + // Block size of the compute tile + auto BLK_M = tile_size<0>(thr_mma); + auto BLK_N = tile_size<1>(thr_mma); + auto BLK_K = tile_size<2>(thr_mma); + + // Compute the "residues" + auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M] + auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N] + auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0] + + // Shift the origin so k_residue is zeroth tile + sA.data() = &sA(0,k_residue); + sB.data() = &sB(0,k_residue); + +#if 0 + if (thread0()) { + printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M)); + printf("%d in BLK_N (%d)\n", 
int(n_residue), int(BLK_N)); + printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K)); + } +#endif + + // + // MMA Partitioning + // + + // Round the layout extents up to BLK_X + Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K)); + Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K)); + Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N)); + +#if 0 + if (thread0()) { + print("rounded_sA: "); print(rounded_sA); print("\n"); + print("rounded_sB: "); print(rounded_sB); print("\n"); + print("rounded_sC: "); print(rounded_sC); print("\n"); + } +#endif + + // Partition the sA and sB tiles across the threads for the MMA + Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K) + Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K) + Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N) + // Create register tensors for the MMA to operate on + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K) + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) + +#if 0 + if (thread0()) { + print("tCsA: "); print(tCsA); print("\n"); + print("tCsB: "); print(tCsB); print("\n"); + print("tCsC: "); print(tCsC); print("\n"); + print("tCrA: "); print(tCrA); print("\n"); + print("tCrB: "); print(tCrB); print("\n"); + print("tCrC: "); print(tCrC); print("\n"); + } +#endif + + // + // PREDICATION + // + + // Allocate the preds for only the MMA-mode of tCsA and tCsB + Tensor tCpA = make_tensor(size<0>(tCsA)); + Tensor tCpB = make_tensor(size<0>(tCsB)); + + // Create coordinate tensors on a single compute block for predication + Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_M,BLK_K) -> (blk_n,blk_k) + + // Repeat partitioning with thr_mma + Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k) + Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k) + + // Populate the m and n predicates + CUTE_UNROLL + for (int i = 0; i < size(tCpA); ++i) { + tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue); + } + CUTE_UNROLL + for (int i = 0; i < size(tCpB); ++i) { + tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue); + } + +#if 0 + printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n", + threadIdx.x, + int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)), + int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0))); +#endif + + // + // PREFETCH k_block = 0 (with k-predication) + // + + CUTE_UNROLL + for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I + if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k + CUTE_UNROLL + for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m + tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{}; + } + } + } + + CUTE_UNROLL + for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I + if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k + CUTE_UNROLL + for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n + tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? 
sB_load_op(tCsB(i,n,0)) : TypeB{}; + } + } + } + // + // MAINLOOP + // + + // Clear accumulators + clear(tCrC); + + constexpr int K_BLOCK_MAX = size<2>(tCrA); + + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + // static-if load the next k_block. No k-predication required on these loads. + if (k_block < K_BLOCK_MAX-1) + { + // Load the next k_block + int k_next = k_block + 1; + + CUTE_UNROLL + for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M + CUTE_UNROLL + for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m + tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{}; + } + } + + CUTE_UNROLL + for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N + CUTE_UNROLL + for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n + tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{}; + } + } + } + + // GEMM on k_block in registers + gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } + + // + // Epilogue + // + + Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n) + + const bool isBetaZero = (beta == Beta{}); + + // Custom axpby_if for now + CUTE_UNROLL + for (int m = 0; m < size<1>(tCsC); ++m) + { + CUTE_UNROLL + for (int n = 0; n < size<2>(tCsC); ++n) + { + CUTE_UNROLL + for (int i = 0; i < size<0>(tCsC); ++i) + { + if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) && + (n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue)) + { + tCsC(i,m,n) = isBetaZero ? alpha * static_cast(tCrC(i,m,n)) : alpha * static_cast(tCrC(i,m,n)) + beta * static_cast(tCsC(i,m,n)); + } + } + } + } +} + +template ::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_smem::value)> +CUTE_HOST_DEVICE +void +cooperative_gemm(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC) +{ + cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */); +} + +template ::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_smem::value)> +CUTE_HOST_DEVICE +void +gemm(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC, + ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */, + BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */) +{ + cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op); +} + +template ::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_smem::value)> +CUTE_HOST_DEVICE +void +gemm(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC) +{ + cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/algorithm/copy.hpp b/server/punica_kernels/include/cutlass/cute/algorithm/copy.hpp new file mode 100644 index 00000000..50a092d0 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/algorithm/copy.hpp @@ -0,0 +1,393 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include +#include + +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor && dst) +{ + return copy(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_vec(Tensor const& src, + Tensor && dst) +{ + return copy_vec(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_aligned(Tensor const& src, + Tensor && dst) +{ + return copy_aligned(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(CopyPolicy const& copy_policy, + PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(copy_policy, pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(CopyPolicy const& copy_policy, + Tensor const& src, + Tensor && dst) +{ + return copy(copy_policy, src, dst); +} + +// +// copy_if -- Predicated Copy +// + +template +CUTE_HOST_DEVICE +void +copy_if(PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + auto copy_op = select_elementwise_copy(src, dst); + + CUTE_UNROLL + for (int i = 0; i < size(src); ++i) { + if (pred(i)) { + copy_op.copy(src(i), dst(i)); + } + } +} + +// +// copy_if -- Predicated CopyAtom +// + +namespace detail { + +// Trait that detects if atom's traits has a member function with(bool) +template +constexpr bool has_with_bool = false; + +template +constexpr bool has_with_bool().with(declval()))>> = true; + +} // end namespace detail + +template +CUTE_HOST_DEVICE +void +copy_if(Copy_Atom const& copy_atom, + PredTensor const& pred, // (Rest...) + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) 
+{ + static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy + copy_atom.call(src, dst); + } else { // Loop over all but the first mode + constexpr int R = SrcLayout::rank; + Tensor src_v = group_modes<1,R>(src); + Tensor dst_v = group_modes<1,R>(dst); + CUTE_UNROLL + for (int i = 0; i < size<1>(src_v); ++i) { + // If copy traits can be transformed with a predicate value, do it, otherwise branch here + if constexpr (detail::has_with_bool>) { + copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i)); + } else { + if (pred(i)) { + copy_atom.call(src_v(_,i), dst_v(_,i)); + } + } + } + } +} + +// +// copy_vec -- attempt vectorized copy with VecType +// + +template +CUTE_HOST_DEVICE +void +copy_vec(Tensor const& src, + Tensor & dst) +{ + static_assert(sizeof_bits_v >= 8 && sizeof_bits_v % 8 == 0, + "Expected a vectorization type of at least a byte."); + using SrcType = typename SrcEngine::element_type; + using DstType = typename DstEngine::element_type; + if constexpr (sizeof_bits_v == sizeof_bits_v && + sizeof_bits_v > sizeof_bits_v) + { + // Preserve volatility of Src/Dst types. + using SrcVecType = conditional_t, VecType const volatile, VecType const>; + using DstVecType = conditional_t, VecType volatile, VecType >; + Tensor src_v = recast(src); + Tensor dst_v = recast(dst); + +#if 0 + if (thread0()) { + print("copy_vec<%db> -- vectorizing copy:\n", int(sizeof_bits_v)); + print(" "); print(src); print(" => "); print(src_v); print("\n"); + print(" "); print(dst); print(" => "); print(dst_v); print("\n"); + } +#endif + + return copy_if(TrivialPredTensor{}, src_v, dst_v); + } else { +#if 0 + if (thread0()) { + print("copy_vec<%db> -- NOT vectorizing copy:\n", int(sizeof_bits_v)); + print(" "); print(src); print("\n"); + print(" "); print(dst); print("\n"); + } +#endif + + return copy_if(TrivialPredTensor{}, src, dst); + } +} + +// +// copy -- CopyAtom +// + +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom const& copy_atom, + Tensor const& src, + Tensor & dst) +{ + return copy_if(copy_atom, TrivialPredTensor{}, src, dst); +} + +////////////////////////////////////////// +// Special Auto-Vectorizing Overloads +////////////////////////////////////////// + +// Specialization for AutoVectorizingCopyAssumedAlignment +template +CUTE_HOST_DEVICE +void +copy(AutoVectorizingCopyWithAssumedAlignment const&, + Tensor const& src, + Tensor & dst) +{ + constexpr int vec_elem = decltype(max_common_vector(src, dst))::value; + + constexpr int src_bits = sizeof_bits::value; + // When layouts are static, accept vec_bits up to 128 + // When layouts are dynamic, accept vec_bits up to MaxVecBits + constexpr int vec_bits = (is_static::value && is_static::value) ? + cute::min(vec_elem * src_bits, 128) : + cute::min(vec_elem * src_bits, MaxVecBits); + +#if 0 + if (thread0()) { + print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", vec_elem, vec_bits); + print(" "); print(src); print("\n"); + print(" "); print(dst); print("\n"); + } +#endif + + if constexpr (vec_elem > 1 && vec_bits >= 8) { + return copy_vec>(src, dst); + } else { + return copy_if(TrivialPredTensor{}, src, dst); + } +} + +// Auto-vectorizing copy for static layouts +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor & dst) +{ + return copy(AutoVectorizingCopy{}, src, dst); +} + +// Auto-vectorizing copy with assumed alignment of dynamic layout strides up to 128bit. 
+template +CUTE_HOST_DEVICE +void +copy_aligned(Tensor const& src, + Tensor & dst) +{ + return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst); +} + +// Specializaton for Atom AutoVectorizingCopy +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom const&, + Tensor const& src, + Tensor & dst) +{ + return copy(AutoVectorizingCopy{}, src, dst); +} + +// Specializaton for Atom AutoVectorizingCopyAssumedAlignment +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom, Args...> const&, + Tensor const& src, + Tensor & dst) +{ + return copy(AutoVectorizingCopyWithAssumedAlignment{}, src, dst); +} + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +template +CUTE_HOST_DEVICE +void +copy(Copy_Traits const& atom, // Copy_Traits may or may not have the memory barrier in it already + Tensor const& src, + Tensor & dst) +{ + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + static_assert(sizeof_bits::value == sizeof_bits::value); + static_assert((is_gmem::value && is_smem::value) || + (is_smem::value && is_gmem::value), + "Bulk Copy only supports gmem -> smem or smem -> gmem movement."); + // G2S or S2G dispatch + using BULK_COPY_OP = conditional_t::value, + SM90_BULK_COPY_G2S, + SM90_BULK_COPY_S2G>; + + // Find the common subtensor of src and dst + auto tiler = max_common_layout(src, dst); + constexpr int vec_elem = decltype(size(tiler))::value; + constexpr int vec_bits = vec_elem * sizeof_bits_v; + static_assert(vec_bits >= 128, "Expected at least 128-bits for BLKCP"); + + // Construct a new concrete Atom of the vector size + using BulkAtom = Copy_Atom, CT_Args...>, SrcType>; + auto bulk_atom = apply(atom.opargs_, [](auto const&... args) { return BulkAtom{args...}; }); + +#if 0 + if (thread0()) { + print("copy blkcp -- found a max_common_layout of "); print(tiler); print("\n"); + print(" "); print(src); print("\n"); + print(" "); print(dst); print("\n"); + } +#endif + + return copy(bulk_atom, logical_divide(src, tiler), logical_divide(dst, tiler)); +} + +// Backwards-compat. Throw out any extra Copy_Atom args. +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom, CA_Args...> const& atom, + Tensor const& src, + Tensor & dst) +{ + return copy(static_cast const&>(atom), src, dst); +} +#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/algorithm/fill.hpp b/server/punica_kernels/include/cutlass/cute/algorithm/fill.hpp new file mode 100644 index 00000000..52060651 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/algorithm/fill.hpp @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +fill(Tensor&& tensor, T const& value) +{ + return fill(tensor, value); +} + +namespace detail +{ + +// Prefer fill(tensor.data(), value), if possible +template +CUTE_HOST_DEVICE +auto +fill(Tensor& tensor, T const& value, prefer<1>) + -> decltype(fill(tensor.data(), value)) +{ + fill(tensor.data(), value); +} + +// Default implementation +template +CUTE_HOST_DEVICE +void +fill(Tensor& tensor, T const& value, prefer<0>) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = value; + } +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE +void +fill(Tensor& tensor, T const& value) +{ + return detail::fill(tensor, value, prefer<1>{}); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/algorithm/functional.hpp b/server/punica_kernels/include/cutlass/cute/algorithm/functional.hpp new file mode 100644 index 00000000..8e7a58a5 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/algorithm/functional.hpp @@ -0,0 +1,291 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +/** C++14 extensions */ + +namespace cute { + +/**************/ +/** Identity **/ +/**************/ + +struct identity { + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return static_cast(arg); + } +}; + +template +struct constant_fn { + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&&...) const { + return r_; + } + R r_; +}; + +/***********/ +/** Unary **/ +/***********/ + +#define CUTE_LEFT_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return OP static_cast(arg); \ + } \ + } +#define CUTE_RIGHT_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return static_cast(arg) OP ; \ + } \ + } +#define CUTE_NAMED_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return OP (static_cast(arg)); \ + } \ + } + +CUTE_LEFT_UNARY_OP(unary_plus, +); +CUTE_LEFT_UNARY_OP(negate, -); +CUTE_LEFT_UNARY_OP(bit_not, ~); +CUTE_LEFT_UNARY_OP(logical_not, !); +CUTE_LEFT_UNARY_OP(dereference, *); +CUTE_LEFT_UNARY_OP(address_of, &); +CUTE_LEFT_UNARY_OP(pre_increment, ++); +CUTE_LEFT_UNARY_OP(pre_decrement, --); + +CUTE_RIGHT_UNARY_OP(post_increment, ++); +CUTE_RIGHT_UNARY_OP(post_decrement, --); + +CUTE_NAMED_UNARY_OP(abs_fn, abs); +CUTE_NAMED_UNARY_OP(conjugate, cute::conj); + +#undef CUTE_LEFT_UNARY_OP +#undef CUTE_RIGHT_UNARY_OP +#undef CUTE_NAMED_UNARY_OP + +template +struct shift_right_const { + static constexpr int Shift = Shift_; + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return static_cast(arg) >> Shift; + } +}; + +template +struct shift_left_const { + static constexpr int Shift = Shift_; + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return static_cast(arg) << Shift; + } +}; + +/************/ +/** Binary **/ +/************/ + +#define CUTE_BINARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& lhs, U&& rhs) const { \ + return static_cast(lhs) OP static_cast(rhs); \ + } \ + } +#define CUTE_NAMED_BINARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& lhs, U&& rhs) const { \ + return OP (static_cast(lhs), static_cast(rhs)); \ + } \ + } + + +CUTE_BINARY_OP(plus, +); +CUTE_BINARY_OP(minus, -); +CUTE_BINARY_OP(multiplies, *); +CUTE_BINARY_OP(divides, /); +CUTE_BINARY_OP(modulus, %); + +CUTE_BINARY_OP(plus_assign, +=); +CUTE_BINARY_OP(minus_assign, -=); +CUTE_BINARY_OP(multiplies_assign, *=); +CUTE_BINARY_OP(divides_assign, /=); +CUTE_BINARY_OP(modulus_assign, %=); + +CUTE_BINARY_OP(bit_and, &); +CUTE_BINARY_OP(bit_or, |); 
+CUTE_BINARY_OP(bit_xor, ^); +CUTE_BINARY_OP(left_shift, <<); +CUTE_BINARY_OP(right_shift, >>); + +CUTE_BINARY_OP(bit_and_assign, &=); +CUTE_BINARY_OP(bit_or_assign, |=); +CUTE_BINARY_OP(bit_xor_assign, ^=); +CUTE_BINARY_OP(left_shift_assign, <<=); +CUTE_BINARY_OP(right_shift_assign, >>=); + +CUTE_BINARY_OP(logical_and, &&); +CUTE_BINARY_OP(logical_or, ||); + +CUTE_BINARY_OP(equal_to, ==); +CUTE_BINARY_OP(not_equal_to, !=); +CUTE_BINARY_OP(greater, >); +CUTE_BINARY_OP(less, <); +CUTE_BINARY_OP(greater_equal, >=); +CUTE_BINARY_OP(less_equal, <=); + +CUTE_NAMED_BINARY_OP(max_fn, cute::max); +CUTE_NAMED_BINARY_OP(min_fn, cute::min); + +#undef CUTE_BINARY_OP +#undef CUTE_NAMED_BINARY_OP + +/**********/ +/** Fold **/ +/**********/ + +#define CUTE_FOLD_OP(NAME,OP) \ + struct NAME##_unary_rfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(T&&... t) const { \ + return (t OP ...); \ + } \ + }; \ + struct NAME##_unary_lfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(T&&... t) const { \ + return (... OP t); \ + } \ + }; \ + struct NAME##_binary_rfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(U&& u, T&&... t) const { \ + return (t OP ... OP u); \ + } \ + }; \ + struct NAME##_binary_lfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(U&& u, T&&... t) const { \ + return (u OP ... OP t); \ + } \ + } + +CUTE_FOLD_OP(plus, +); +CUTE_FOLD_OP(minus, -); +CUTE_FOLD_OP(multiplies, *); +CUTE_FOLD_OP(divides, /); +CUTE_FOLD_OP(modulus, %); + +CUTE_FOLD_OP(plus_assign, +=); +CUTE_FOLD_OP(minus_assign, -=); +CUTE_FOLD_OP(multiplies_assign, *=); +CUTE_FOLD_OP(divides_assign, /=); +CUTE_FOLD_OP(modulus_assign, %=); + +CUTE_FOLD_OP(bit_and, &); +CUTE_FOLD_OP(bit_or, |); +CUTE_FOLD_OP(bit_xor, ^); +CUTE_FOLD_OP(left_shift, <<); +CUTE_FOLD_OP(right_shift, >>); + +CUTE_FOLD_OP(bit_and_assign, &=); +CUTE_FOLD_OP(bit_or_assign, |=); +CUTE_FOLD_OP(bit_xor_assign, ^=); +CUTE_FOLD_OP(left_shift_assign, <<=); +CUTE_FOLD_OP(right_shift_assign, >>=); + +CUTE_FOLD_OP(logical_and, &&); +CUTE_FOLD_OP(logical_or, ||); + +CUTE_FOLD_OP(equal_to, ==); +CUTE_FOLD_OP(not_equal_to, !=); +CUTE_FOLD_OP(greater, >); +CUTE_FOLD_OP(less, <); +CUTE_FOLD_OP(greater_equal, >=); +CUTE_FOLD_OP(less_equal, <=); + +#undef CUTE_FOLD_OP + +/**********/ +/** Meta **/ +/**********/ + +template +struct bound_fn { + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(T&& arg) { + return fn_(arg_, static_cast(arg)); + } + + Fn fn_; + Arg arg_; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +bind(Fn const& fn, Arg const& arg) { + return bound_fn{fn, arg}; +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/algorithm/gemm.hpp b/server/punica_kernels/include/cutlass/cute/algorithm/gemm.hpp new file mode 100644 index 00000000..27c32216 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/algorithm/gemm.hpp @@ -0,0 +1,500 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +#include + +#include + +/** The gemm algorithm takes four (or three) tensors and computes + * D = A * B + C + * It dispatches based on the number of modes each tensor has: + * + * 1. `(V) x (V) => (V)`. + * The element-wise product of vectors. Dispatches to FMA or MMA. + * 2. `(M) x (N) => (M,N)`. + * The outer product of vectors. Dispatches to [3] with new mode K=(1). + * 3. `(M,K) x (N,K) => (M,N)`. + * The product of matrices. Dispatches to [5] with MMA vector-mode V. + * 4. `(V,M) x (V,N) => (V,M,N)`. + * The batched outer product of vectors. Accounts for register reuse and dispatches to [1] for each (m,n). + * 5. `(V,M,K) x (V,N,K) => (V,M,N)`. + * The batched product of matrices. Dispatches to [4] for each (k). 
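+ *
+ * Illustrative note (an editor's sketch, not part of the upstream CUTLASS header):
+ * with the default UniversalFMA atom defined further down in this file, case [1]
+ * on register fragments is effectively an element-wise fused multiply-add,
+ *
+ *   for (int i = 0; i < size(D); ++i) { D(i) = A(i) * B(i) + C(i); }
+ *
+ * while case [5] simply peels the K mode, invoking case [4] once per k-block,
+ * as the dispatch overloads below show.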
+ */ + +namespace cute +{ + +// +// Three arguments to four +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor const& A, + Tensor const& B, + Tensor & C) +{ + return gemm(C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor const& A, + Tensor const& B, + Tensor & C) +{ + return gemm(mma, C, A, B, C); +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor const& A, + Tensor const& B, + Tensor && C) +{ + return gemm(C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + return gemm(D, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor const& A, + Tensor const& B, + Tensor && C) +{ + return gemm(mma, C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + return gemm(mma, D, A, B, C); +} + +// +// Default MMA is UniversalFMA +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + using MMA = MMA_Atom::value_type, + typename Tensor::value_type, + typename Tensor::value_type, + typename Tensor::value_type>>; + + return gemm(MMA{}, D, A, B, C); +} + +// +// Thread-Local Register-Memory GEMMs +// + +// Dispatch [1]: (V) x (V) => (V) +template ::value && + ALayout::rank == 1 && is_rmem::value && + BLayout::rank == 1 && is_rmem::value && + CLayout::rank == 1 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V) Logical data + Tensor const& A, // (V) Logical data + Tensor const& B, // (V) Logical data + Tensor const& C) // (V) Logical data +{ + // No static assertions on (V), MMA checks compatibility + mma.call(D, A, B, C); +} + +// Dispatch [2]: (M) x (N) => (M,N) +template ::value && + ALayout::rank == 1 && is_rmem::value && + BLayout::rank == 1 && is_rmem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M) Logical data + Tensor const& B, // (N) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + gemm(mma, + D, // (M,N) + make_tensor(A.data(), append<2>(A.layout())), // (M,1) + make_tensor(B.data(), append<2>(B.layout())), // (N,1) + C); // (M,N) +} + +// Dispatch [3]: (M,K) x (N,K) => (M,N) +template ::value && + ALayout::rank == 2 && is_rmem::value && + BLayout::rank == 2 && is_rmem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M,K) Logical data + Tensor const& B, // (N,K) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + + // Assert this is a 1-value MMA + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); + + gemm(mma, + 
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) + make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) + make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) + make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) +} + +// Dispatch [4]: (V,M) x (V,N) => (V,M,N) +template ::value && + ALayout::rank == 2 && is_rmem::value && + BLayout::rank == 2 && is_rmem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M) Logical data + Tensor const& B, // (V,N) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + auto M = size<1>(A); + auto N = size<1>(B); + // REGISTER .reuse OPTIMIZATIONS + // 64-bit traversal specialization -- serpentine path + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 8 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 8) + { +#if 1 // NOTE: Row- vs Col- major could depend on the C-matrix order... (which we can test) + // Row-major serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 1) ? N-1-n : n; // Serpentine coordinate + gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns)); + } + } +#else + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } +#endif + } else + // 32-bit traversal specialization -- kinked serpentine path + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 4 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 4) + { +#if 1 // NOTE: Row- vs Col- major could depend on the C-matrix order... (which we can test) + // Row-major kinked serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; m += 2) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 2) ? N-1-n : n; + gemm(mma, D(_,m+0,ns), A(_,m+0), B(_,ns), C(_,m+0,ns)); + + if (m+1 < M) { + gemm(mma, D(_,m+1,ns), A(_,m+1), B(_,ns), C(_,m+1,ns)); + } + } + } +#else + // Col-major kinked serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; n += 2) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + // Kinked serpentine traversal for maximum register reuse + int ms = (n & 2) ? M-1-m : m; + gemm(mma, D(_,ms,n+0), A(_,ms), B(_,n+0), C(_,ms,n+0)); + + if (n+1 < N) { + gemm(mma, D(_,ms,n+1), A(_,ms), B(_,n+1), C(_,ms,n+1)); + } + } + } +#endif + } else + // 64-bit + 32-bit traversal order -- keep A (64-bit) in the outer loop and serpentine B + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 8 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 4) { + // Row-major serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 1) ? 
N-1-n : n; // Serpentine coordinate + gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns)); + } + } + } else + // 32-bit + 64-bit traversal order -- keep B (64-bit) in the outer loop and serpentine A + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 4 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 8) { + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } + } else + // Fallback to serpentine loop + { + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } + } +} + +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +template ::value && + ALayout::rank == 3 && is_rmem::value && + BLayout::rank == 3 && is_rmem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M,K) Logical data + Tensor const& B, // (V,N,K) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + auto K = size<2>(A); + + CUTE_UNROLL + for (int k = 0; k < K; ++k) { + gemm(mma, D, A(_,_,k), B(_,_,k), C); + } +} + +// +// Thread-Local Shared-Memory GEMMs +// + +// Dispatch [1]: (V) x (V) => (V) +// Dispatch [2]: (M) x (N) => (M,N) +// Dispatch [3]: (M,K) x (N,K) => (M,N) +// Dispatch [4]: (V,M) x (V,N) => (V,M,N) +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +// Dispatch [3]: (M,K) x (N,K) => (M,N) +template ::value && + ALayout::rank == 2 && is_smem::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M,K) Logical data + Tensor const& B, // (N,K) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + + // Assert this is a 1-value MMA + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); + + gemm(mma, + make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) + make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) + make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) + make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) +} + +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +template ::value && + ALayout::rank == 3 && is_smem::value && + BLayout::rank == 3 && is_smem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M,K) Logical data + Tensor const& B, // (V,N,K) Logical data + Tensor const& C) 
// (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + + auto rA = MMA_Atom::make_fragment_A(A); + auto rB = MMA_Atom::make_fragment_B(B); + + auto K = size<2>(A); + + CUTE_UNROLL + for (int k = 0; k < K; ++k) + { + copy(A(_,_,k), rA(_,_,k)); + copy(B(_,_,k), rB(_,_,k)); + // Thread-level register gemm for k + gemm(mma, D, rA(_,_,k), rB(_,_,k), C); + } +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/algorithm/prefer.hpp b/server/punica_kernels/include/cutlass/cute/algorithm/prefer.hpp new file mode 100644 index 00000000..a69e5042 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/algorithm/prefer.hpp @@ -0,0 +1,46 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +namespace cute +{ + +// Infinite types that inherit from each other +template +struct prefer : prefer {}; + +template <> +struct prefer<0> {}; + +// Can be used to preferencially overload implementations +// Higher N in prefer have higher priority. + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/algorithm/prefetch.hpp b/server/punica_kernels/include/cutlass/cute/algorithm/prefetch.hpp new file mode 100644 index 00000000..47aefa87 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/algorithm/prefetch.hpp @@ -0,0 +1,153 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include + +namespace cute +{ + +// +// Prefetch global tensors into L2 +// + +template +CUTE_HOST_DEVICE +void +cooperative_prefetch(uint32_t const& tid, + Tensor const& src) +{ + static_assert(is_gmem::value, "Expected global tensor for prefetch"); + + constexpr int V = decltype(max_common_vector(src, src))::value; + + if constexpr (V > 1) { + // L2 sector is 32B, default fetch granularity is 64B + using VecType = conditional_t<(V * sizeof_bits_v) < (FetchBytes * 8), + ArrayEngine, + uint8_t[FetchBytes] >; + + Tensor src_v = recast(src); + CUTE_UNROLL + for (int i = tid; i < size(src_v); i += NumThreads) { + prefetch(raw_pointer_cast(&src_v(i))); + } + } else { + CUTE_UNROLL + for (int i = tid; i < size(src); i += NumThreads) { + prefetch(raw_pointer_cast(&src(i))); + } + } +} + +template +CUTE_HOST_DEVICE +void +prefetch(Tensor const& src) +{ + return cooperative_prefetch<1>(0, src); +} + +// Prefetch with copy atom +namespace detail { + +template +constexpr bool has_prefetch = false; + +template +constexpr bool has_prefetch> = true; + +template +constexpr bool is_prefetch = false; + +template +constexpr bool is_prefetch> = is_same_v; + +} // end namespace detail + +template +CUTE_HOST_DEVICE +void +prefetch(Copy_Atom, CA_Args...> const& atom, + Tensor const& src) +{ + if constexpr (detail::has_prefetch) { + using Prefetch_Traits = Copy_Traits; + using Prefetch_Atom = Copy_Atom; + Prefetch_Atom prefetch_atom{atom}; + auto& dst = const_cast&>(src); // dst is ignored for prefetch atoms + return copy(prefetch_atom, src, dst); + } else { + return prefetch(src); + } +} + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +template +CUTE_HOST_DEVICE +void +prefetch(Copy_Traits const& atom, + Tensor const& src) +{ + using SrcType = typename SrcEngine::value_type; + static_assert(is_gmem::value, 
"Expected global tensor for L2 prefetch"); + + auto tiler = max_common_layout(src, src); + constexpr int vec_elem = decltype(size(tiler))::value; + constexpr int vec_bits = vec_elem * sizeof_bits_v; + static_assert(vec_bits >= 128, "Expected at least 128-bits for BLKCP"); + + // Construct a new concrete Atom of the vector size + auto bulk_atom = Copy_Atom>, SrcType>{}; + + return prefetch(bulk_atom, logical_divide(src, tiler)); +} + +// Backwards-compat. Throw out any extra Copy_Atom args. +template +CUTE_HOST_DEVICE +void +prefetch(Copy_Atom, CA_Args...> const& atom, + Tensor const& src) +{ + return prefetch(static_cast const&>(atom), src); +} +#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/algorithm/tensor_algorithms.hpp b/server/punica_kernels/include/cutlass/cute/algorithm/tensor_algorithms.hpp new file mode 100644 index 00000000..0faf68a6 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/algorithm/tensor_algorithms.hpp @@ -0,0 +1,167 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/** Common algorithms on (hierarchical) tensors */ + +#pragma once + +#include + +#include + +namespace cute +{ + +// +// for_each +// + +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor const& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + op(tensor(i)); + } +} + +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + op(tensor(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor&& tensor, UnaryOp&& op) +{ + return for_each(tensor, op); +} + +// +// transform +// + +// Similar to std::transform but does not return number of elements affected +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = op(tensor(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor&& tensor, UnaryOp&& op) +{ + return transform(tensor, op); +} + +// Similar to std::transform transforms one tensors and assigns it to another +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in, + Tensor & tensor_out, + UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor_in); ++i) { + tensor_out(i) = op(tensor_in(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in, + Tensor && tensor_out, + UnaryOp&& op) +{ + return transform(tensor_in, tensor_out, op); +} + +// Similar to std::transform with a binary operation +// Takes two tensors as input and one tensor as output. +// Applies the binary_op to tensor_in1 and tensor_in2 and +// assigns it to tensor_out +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in1, + Tensor const& tensor_in2, + Tensor & tensor_out, + BinaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor_in1); ++i) { + tensor_out(i) = op(tensor_in1(i), tensor_in2(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in1, + Tensor const& tensor_in2, + Tensor && tensor_out, + BinaryOp&& op) +{ + return transform(tensor_in1, tensor_in2, tensor_out, op); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/algorithm/tuple_algorithms.hpp b/server/punica_kernels/include/cutlass/cute/algorithm/tuple_algorithms.hpp new file mode 100644 index 00000000..6c97e10e --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/algorithm/tuple_algorithms.hpp @@ -0,0 +1,1091 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include + +/// @file tuple_algorithms.hpp +/// @brief Common algorithms on (hierarchical) tuples +/// +/// Code guidelines and style preferences: +/// +/// For perfect forwarding, don't use std::forward, because it may not +/// be defined in device code when compiling with NVRTC. Instead, use +/// `static_cast(parameter_name)`. +/// +/// CuTe generally does not bother forwarding functions, as +/// reference-qualified member functions are rare in this code base. +/// +/// Throughout CUTLASS, cute::make_tuple always needs to be called +/// namespace-qualified, EVEN If inside the cute namespace and/or in +/// scope of a "using namespace cute" declaration. Otherwise, the +/// compiler may select std::make_tuple instead of cute::make_tuple, +/// due to argument-dependent lookup. Two problems may result from +/// that. +/// +/// 1. Functions have an unexpected return type (std::tuple instead of +/// cute::tuple), so functions that take cute::tuple parameters +/// fail to compile (generally inside functions that have template +/// parameters expected to be cute::tuple). +/// +/// 2. std::tuple does not have the required __host__ __device__ +/// markings, so the CUDA compiler complains if you use it in +/// device code. +/// +/// cute::make_tuple will occur more often than std::make_tuple would +/// in modern C++ code, because cute::tuple's design deprioritizes +/// correct operation of CTAD (constructor template argument +/// deduction) in favor of implementation simplicity. + +namespace cute +{ + +// +// Apply (Unpack) +// (t, f) => f(t_0,t_1,...,t_n) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +apply(T&& t, F&& f, seq) +{ + return f(get(static_cast(t))...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +apply(T&& t, F&& f) +{ + return detail::apply(static_cast(t), f, tuple_seq{}); +} + +// +// Transform Apply +// (t, f, g) => g(f(t_0),f(t_1),...) 
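+// Illustrative call (hypothetical static values, shown only to fix the calling convention):
+//   transform_apply(cute::make_tuple(_1{}, _2{}, _3{}),
+//                   [](auto x) { return x * _2{}; },        // f: double each element
+//                   [](auto... v) { return (v + ...); });   // g: sum the results => _12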
+// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T&& t, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t)))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T0&& t0, T1&& t1, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t0)), + get(static_cast(t1)))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t0)), + get(static_cast(t1)), + get(static_cast(t2)))...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T&& t, F&& f, G&& g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t))); + } +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T0&& t0, T1&& t1, F&& f, G&& g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t0), static_cast(t1), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t0), static_cast(t1))); + } +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t0), static_cast(t1), static_cast(t2), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t0), static_cast(t1), static_cast(t2))); + } +} + +// +// For Each +// (t, f) => f(t_0),f(t_1),...,f(t_n) +// + +template +CUTE_HOST_DEVICE constexpr +void +for_each(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::apply(t, [&](auto&&... a) { (f(static_cast(a)), ...); }, tuple_seq{}); + } else { + return f(static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +for_each_leaf(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::apply(static_cast(t), [&](auto&&... a){ return (for_each_leaf(static_cast(a), f), ...); }, tuple_seq{}); + } else { + return f(static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Transform +// (t, f) => (f(t_0),f(t_1),...,f(t_n)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T0 const& t0, T1 const& t1, F&& f) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t0, t1); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, t2, f, [](auto const&... 
a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t0, t1, t2); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_leaf(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return transform(t, [&](auto const& a) { return transform_leaf(a, f); }); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_leaf(T0 const& t0, T1 const& t1, F&& f) +{ + if constexpr (is_tuple::value) { + return transform(t0, t1, [&](auto const& a, auto const& b) { return transform_leaf(a, b, f); }); + } else { + return f(t0, t1); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// find and find_if +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +find_if(T const& t, F&& f, seq) +{ + if constexpr (decltype(f(get(t)))::value) { + return cute::C{}; + } else + if constexpr (sizeof...(Is) == 0) { + return cute::C{}; + } else { + return find_if(t, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +find_if(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::find_if(t, f, tuple_seq{}); + } else { + return cute::C{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +find(T const& t, X const& x) +{ + return find_if(t, [&](auto const& v) { return v == x; }); // This should always return a static true/false +} + +template +CUTE_HOST_DEVICE constexpr +auto +any_of(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (false_type{} || ... || a); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +all_of(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::apply(t, [&] (auto const&... a) { return (true_type{} && ... && f(a)); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +none_of(T const& t, F&& f) +{ + return not any_of(t, f); +} + +// +// Filter +// (t, f) => +// + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T const& t, F&& f) +{ + return transform_apply(t, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T0 const& t0, T1 const& t1, F&& f) +{ + return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) +{ + return transform_apply(t0, t1, t2, f, [](auto const&... 
a) { return cute::tuple_cat(a...); }); +} + +// +// Fold (Reduce, Accumulate) +// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n) +// + +namespace detail { + +// This impl compiles much faster than cute::apply and variadic args +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +fold(T&& t, V&& v, F&& f, seq<>) +{ + return static_cast(v); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +fold(T&& t, V&& v, F&& f, seq) +{ + if constexpr (sizeof...(Is) == 0) { + return f(static_cast(v), get(static_cast(t))); + } else { + return fold(static_cast(t), + f(static_cast(v), get(static_cast(t))), + f, + seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +fold(T&& t, V&& v, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::fold(static_cast(t), + static_cast(v), + f, + tuple_seq{}); + } else { + return f(static_cast(v), static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +fold_first(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::fold(static_cast(t), + get<0>(static_cast(t)), + f, + make_range<1,tuple_size>::value>{}); + } else { + return static_cast(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// front, back, take, select, unwrap +// + +// Get the first non-tuple element in a hierarchical tuple +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +front(T&& t) +{ + if constexpr (is_tuple>::value) { + return front(get<0>(static_cast(t))); + } else { + return static_cast(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Get the last non-tuple element in a hierarchical tuple +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +back(T&& t) +{ + if constexpr (is_tuple>::value) { + constexpr int N = tuple_size>::value; + + // MSVC needs a bit of extra help here deducing return types. + // We help it by peeling off the nonrecursive case a level "early." + if constexpr (! is_tuple(static_cast(t)))>>::value) { + return get(static_cast(t)); + } else { + return back(get(static_cast(t))); + } + } else { + return static_cast(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Takes the elements in the range [B,E) +template +CUTE_HOST_DEVICE constexpr +auto +take(T const& t) +{ + return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); +} + +// +// Select tuple elements with given indices. +// + +template +CUTE_HOST_DEVICE constexpr +auto +select(T const& t) +{ + return cute::make_tuple(get(t)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +select(T const& t, Indices const& indices) +{ + if constexpr (is_tuple::value) { + return cute::transform(indices, [&t](auto i) { return select(t, i); }); + } else { + static_assert(is_static::value, "Order must be static"); + return get(t); + } +} + +// Wrap non-tuples into rank-1 tuples or forward +template +CUTE_HOST_DEVICE constexpr +auto +wrap(T const& t) +{ + if constexpr (is_tuple::value) { + return t; + } else { + return cute::make_tuple(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple +template +CUTE_HOST_DEVICE constexpr +auto +unwrap(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (tuple_size::value == 1) { + return unwrap(get<0>(t)); + } else { + return t; + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Flatten and Unflatten +// + +template +struct is_flat : true_type {}; + +template +struct is_flat> : bool_constant<(true && ... 
&& (not is_tuple::value))> {}; + +// Flatten a hierarchical tuple to a tuple of depth one +// and wrap non-tuples into a rank-1 tuple. +template +CUTE_HOST_DEVICE constexpr +auto +flatten_to_tuple(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_flat::value) { // Shortcut for perf + return t; + } else { + return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + } + } else { + return cute::make_tuple(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Flatten a hierarchical tuple to a tuple of depth one +// and leave non-tuple untouched. +template +CUTE_HOST_DEVICE constexpr +auto +flatten(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_flat::value) { // Shortcut for perf + return t; + } else { + return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile) +{ + if constexpr (is_tuple::value) { + return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) { + auto [result, remaining_tuple] = v; + auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t); + return cute::make_tuple(append(result, sub_result), sub_tuple); + }); + } else { + return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple)); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Unflatten a flat tuple into a hierarchical tuple +// @pre flatten(@a flat_tuple) == @a flat_tuple +// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple) +// @post congruent(@a result, @a target_profile) +// @post flatten(@a result) == @a flat_tuple +template +CUTE_HOST_DEVICE constexpr +auto +unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile) +{ + auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile); + CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{}); + return unflatten_tuple; +} + +// +// insert and remove and replace +// + +namespace detail { + +// Shortcut around cute::tuple_cat for common insert/remove/repeat cases +template +CUTE_HOST_DEVICE constexpr +auto +construct(T const& t, X const& x, seq, seq, seq) +{ + return cute::make_tuple(get(t)..., (void(J),x)..., get(t)...); +} + +} // end namespace detail + +// Insert x into the Nth position of the tuple +template +CUTE_HOST_DEVICE constexpr +auto +insert(T const& t, X const& x) +{ + return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); +} + +// Remove the Nth element of the tuple +template +CUTE_HOST_DEVICE constexpr +auto +remove(T const& t) +{ + return detail::construct(t, 0, make_seq{}, seq<>{}, make_range::value>{}); +} + +// Replace the Nth element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace(T const& t, X const& x) +{ + return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); +} + +// Replace the first element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace_front(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(t, x, seq<>{}, seq<0>{}, make_range<1,tuple_size::value>{}); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Replace the last element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace_back(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { 
+ return detail::construct(t, x, make_seq::value-1>{}, seq<0>{}, seq<>{}); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Make a tuple of Xs of tuple_size N +// + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_repeat(X const& x) +{ + return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); +} + +// +// Make repeated Xs of rank N +// + +template +CUTE_HOST_DEVICE constexpr +auto +repeat(X const& x) +{ + if constexpr (N == 1) { + return x; + } else { + return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Make a tuple of Xs the same profile as tuple T +// + +template +CUTE_HOST_DEVICE constexpr +auto +repeat_like(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return transform(t, [&](auto const& a) { return repeat_like(a,x); }); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Group the elements [B,E) of a T into a single element +// e.g. group<2,4>(T<_1,_2,_3,_4,_5,_6>{}) +// => T<_1,_2,T<_3,_4>,_5,_6>{} +template +CUTE_HOST_DEVICE constexpr +auto +group(T const& t) +{ + if constexpr (not is_tuple::value) { + if constexpr (E == -1) { + return group(t); + } else { + return detail::construct(t, take(t), make_seq{}, make_seq<(B < E)>{}, make_range{}); + } + } else + if constexpr (E == -1) { + return group::value>(t); + } else + if constexpr (B <= E) { + return detail::construct(t, take(t), make_seq{}, make_seq<(B < E)>{}, make_range::value>{}); + } else { + static_assert(B <= E); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Extend a T to rank N by appending/prepending an element +// + +template +CUTE_HOST_DEVICE constexpr +auto +append(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + if constexpr (N == tuple_size::value) { + return a; + } else { + static_assert(N > tuple_size::value); + return detail::construct(a, x, make_seq::value>{}, make_seq::value>{}, seq<>{}); + } + } else { + if constexpr (N == 1) { + return a; + } else { + return detail::construct(cute::make_tuple(a), x, seq<0>{}, make_seq{}, seq<>{}); + } + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(a, x, make_seq::value>{}, seq<0>{}, seq<>{}); + } else { + return cute::make_tuple(a, x); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + if constexpr (N == tuple_size::value) { + return a; + } else { + static_assert(N > tuple_size::value); + return detail::construct(a, x, seq<>{}, make_seq::value>{}, make_seq::value>{}); + } + } else { + if constexpr (N == 1) { + return a; + } else { + static_assert(N > 1); + return detail::construct(cute::make_tuple(a), x, seq<>{}, make_seq{}, seq<0>{}); + } + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(a, x, seq<>{}, seq<0>{}, make_seq::value>{}); + } else { + return cute::make_tuple(x, a); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Inclusive scan (prefix sum) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +iscan(T const& t, V const& v, F&& f, seq) +{ + // Apply the function to v and the element at I + auto v_next = f(v, get(t)); + // Replace I with v_next + auto t_next = replace(t, v_next); + +#if 0 + std::cout << "ISCAN i" << I << std::endl; + std::cout << " t " << t << std::endl; + std::cout << " i 
" << v << std::endl; + std::cout << " f(i,t) " << v_next << std::endl; + std::cout << " t_n " << t_next << std::endl; +#endif + + if constexpr (sizeof...(Is) == 0) { + return t_next; + } else { + return iscan(t_next, v_next, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +iscan(T const& t, V const& v, F&& f) +{ + return detail::iscan(t, v, f, tuple_seq{}); +} + +// +// Exclusive scan (prefix sum) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +escan(T const& t, V const& v, F&& f, seq) +{ + if constexpr (sizeof...(Is) == 0) { + // Replace I with v + return replace(t, v); + } else { + // Apply the function to v and the element at I + auto v_next = f(v, get(t)); + // Replace I with v + auto t_next = replace(t, v); + +#if 0 + std::cout << "ESCAN i" << I << std::endl; + std::cout << " t " << t << std::endl; + std::cout << " i " << v << std::endl; + std::cout << " f(i,t) " << v_next << std::endl; + std::cout << " t_n " << t_next << std::endl; +#endif + + // Recurse + return escan(t_next, v_next, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +escan(T const& t, V const& v, F&& f) +{ + return detail::escan(t, v, f, tuple_seq{}); +} + +// +// Zip (Transpose) +// + +// Take ((a,b,c,...),(x,y,z,...),...) rank-R0 x rank-R1 input +// to produce ((a,x,...),(b,y,...),(c,z,...),...) rank-R1 x rank-R0 output + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +zip_(Ts const&... ts) +{ + return cute::make_tuple(get(ts)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zip(T const& t, seq, seq) +{ + static_assert(conjunction>::value == tuple_size>::value>...>::value, "Mismatched Ranks"); + return cute::make_tuple(zip_(get(t)...)...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +zip(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple>::value) { + return detail::zip(t, tuple_seq{}, tuple_seq>{}); + } else { + return cute::make_tuple(t); + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// Convenient to pass them in separately +template +CUTE_HOST_DEVICE constexpr +auto +zip(T0 const& t0, T1 const& t1, Ts const&... ts) +{ + return zip(cute::make_tuple(t0, t1, ts...)); +} + +// +// zip2_by -- A guided zip for rank-2 tuples +// Take a tuple like ((A,a),((B,b),(C,c)),d) +// and produce a tuple ((A,(B,C)),(a,(b,c),d)) +// where the rank-2 modes are selected by the terminals of the guide (X,(X,X)) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +zip2_by(T const& t, TG const& guide, seq, seq) +{ + // zip2_by produces the modes like ((A,a),(B,b),...) 
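+  // Illustrative trace (reusing the example tuple from the comment above, no new behavior):
+  //   t      = ((A,a), ((B,b),(C,c)), d)   with guide (X,(X,X))
+  //   split  = ((A,a), ((B,C),(b,c)))      -- each guided mode zipped recursively
+  //   result = ((A,(B,C)), (a,(b,c), d))   -- first modes gathered, then second modes plus the leftover d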
+ auto split = cute::make_tuple(zip2_by(get(t), get(guide))...); + + // Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y)) + return cute::make_tuple(cute::make_tuple(get<0>(get(split))...), + cute::make_tuple(get<1>(get(split))..., get(t)...)); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +zip2_by(T const& t, TG const& guide) +{ + if constexpr (is_tuple::value) { + constexpr int TR = tuple_size::value; + constexpr int GR = tuple_size::value; + static_assert(TR >= GR, "Mismatched ranks"); + return detail::zip2_by(t, guide, + make_range< 0, GR>{}, + make_range{}); + } else { + static_assert(tuple_size::value == 2, "Mismatched ranks"); + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +/// @return A tuple of the elements of @c t in reverse order. +template +CUTE_HOST_DEVICE constexpr +auto +reverse(T const& t) +{ + if constexpr (is_tuple::value) { + return detail::apply(t, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_rseq{}); + } else { + return t; + } +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/cluster_sm90.hpp b/server/punica_kernels/include/cutlass/cute/arch/cluster_sm90.hpp new file mode 100644 index 00000000..27a34d77 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/cluster_sm90.hpp @@ -0,0 +1,245 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ + ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))) +# define CUTE_ARCH_CLUSTER_SM90_ENABLED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +# define CUTE_ARCH_ELECT_ONE_SM90_ENABLED +#endif + +namespace cute { + +CUTE_DEVICE void cluster_arrive_relaxed() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : : ); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +CUTE_DEVICE void cluster_arrive() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.arrive.aligned;\n" : : ); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +CUTE_DEVICE void cluster_wait() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.wait.aligned;\n" : : ); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +CUTE_DEVICE void cluster_sync() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + cluster_arrive(); + cluster_wait(); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +// Returns the dim3 grid size in terms of number of clusters. +CUTE_DEVICE dim3 cluster_grid_dims() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%nclusterid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%nclusterid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%nclusterid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#elif defined(__CUDA_ARCH__) + // MSVC requires protecting use of gridDim with __CUDA_ARCH__. + return gridDim; +#elif defined(_MSC_VER) + CUTE_INVALID_CONTROL_PATH("cluster_grid_dims() can only be called on device"); + return {0, 0, 0}; +#else + return {0, 0, 0}; +#endif +} + +// Returns the dim3 cluster rank in the grid. +CUTE_DEVICE dim3 cluster_id_in_grid() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%clusterid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%clusterid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#elif defined(__CUDA_ARCH__) + // MSVC requires protecting use of blockIdx with __CUDA_ARCH__. + return blockIdx; +#elif defined(_MSC_VER) + CUTE_INVALID_CONTROL_PATH("cluster_id_in_grid() can only be called on device"); + return {0, 0, 0}; +#else + return {0, 0, 0}; +#endif +} + +// Returns the relative dim3 block rank local to the cluster. +CUTE_DEVICE dim3 block_id_in_cluster() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%cluster_ctaid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%cluster_ctaid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%cluster_ctaid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#else + return {0,0,0}; +#endif +} + +// Returns the dim3 cluster shape. 
+CUTE_DEVICE dim3 cluster_shape() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%cluster_nctaid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%cluster_nctaid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%cluster_nctaid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#else + return {1,1,1}; +#endif +} + +// Get 1D ctaid in a cluster. +CUTLASS_DEVICE uint32_t block_rank_in_cluster() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t rank; + asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :); + return rank; +#else + return 0; +#endif +} + +// Set the destination block-ID in cluster for a given SMEM Address +CUTLASS_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t result; + asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(result) + : "r"(smemAddr), "r"(rank)); + return result; +#else + return smemAddr; +#endif +} + +// Elect one thread in the warp. The elected thread gets its predicate set to true, all others obtain false. +CUTE_HOST_DEVICE uint32_t elect_one_sync() +{ +#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +#elif defined(__CUDA_ARCH__) + return (threadIdx.x % 32) == 0; +#else + return true; +#endif +} + +struct ElectOneLaneIdReturnType { + uint32_t is_leader; + uint32_t leader_lane_id; +}; + +CUTE_HOST_DEVICE +ElectOneLaneIdReturnType +elect_one_leader_sync() +{ +#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return {pred, laneid}; +#elif defined(__CUDA_ARCH__) + return {(threadIdx.x % 32) == 0, 0}; +#else + return {true, 0}; +#endif +} + +// Store value to remote shared memory in the cluster +CUTE_DEVICE +void +store_shared_remote(uint32_t value, uint32_t smem_addr, uint32_t mbarrier_addr, uint32_t dst_cta_rank) +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t dsmem_addr = set_block_rank(smem_addr, dst_cta_rank); + uint32_t remote_barrier_addr = set_block_rank(mbarrier_addr, dst_cta_rank); + asm volatile("st.async.shared::cluster.mbarrier::complete_tx::bytes.u32 [%0], %1, [%2];" + : : "r"(dsmem_addr), "r"(value), "r"(remote_barrier_addr)); +#endif +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/copy.hpp b/server/punica_kernels/include/cutlass/cute/arch/copy.hpp new file mode 100644 index 00000000..b85e6a20 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/copy.hpp @@ -0,0 +1,105 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// +// Direct Copy for any type +// + +template +struct UniversalCopy +{ + using SRegisters = S[1]; + using DRegisters = D[1]; + + template + CUTE_HOST_DEVICE static constexpr void + copy(S_ const& src, + D_ & dst) + { + dst = static_cast(static_cast(src)); + } + + // Accept mutable temporaries + template + CUTE_HOST_DEVICE static constexpr void + copy(S_ const& src, + D_ && dst) + { + UniversalCopy::copy(src, dst); + } +}; + +// +// Placeholder for the copy algorithm's stronger auto-vectorizing behavior +// that assumes alignment of dynamic layouts up to MaxVecBits +// + +template +struct AutoVectorizingCopyWithAssumedAlignment + : UniversalCopy> +{ + static_assert(MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128, + "Expected MaxVecBits to be 8 or 16 or 32 or 64 or 128 for alignment and performance."); +}; + +// +// Placeholder for the copy algorithm's default auto-vectorizing behavior +// that does not assume alignment of dynamic layouts +// + +using AutoVectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<8>; + +// Alias +using DefaultCopy = AutoVectorizingCopy; + + +// +// Global memory prefetch into L2 +// + +CUTE_HOST_DEVICE static void +prefetch(void const* gmem_ptr) +{ +#if defined(__CUDA_ARCH__) + asm volatile("prefetch.global.L2 [%0];\n" : : "l"(gmem_ptr) : "memory"); +#endif +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/copy_sm75.hpp b/server/punica_kernels/include/cutlass/cute/arch/copy_sm75.hpp new file mode 100644 index 00000000..3d3d37ac --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/copy_sm75.hpp @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if defined(__clang__) && defined(__CUDA__) + // ldmatrix PTX instructions added in Clang 14: https://reviews.llvm.org/D107046 + // ... but will not work until Clang 15: + // * https://reviews.llvm.org/D121666 + // * https://reviews.llvm.org/D126846 + #define CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75 (__clang_major__ >= 15) +#endif + +#if defined(__NVCC__) || defined(__CUDACC_RTC__) + // ldmatrix PTX instruction added in CUDA 10.2+ + #define CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 ((__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || __CUDACC_VER_MAJOR__ >= 11) +#endif + +#if ! defined(CUTE_ARCH_LDSM_SM75_SUPPORTED) + #define CUTE_ARCH_LDSM_SM75_SUPPORTED (CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 || CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75) +#endif + +#if ! 
defined(CUTE_ARCH_LDSM_SM75_ENABLED) + #define CUTE_ARCH_LDSM_SM75_ENABLED (CUTE_ARCH_LDSM_SM75_SUPPORTED) +#endif + +#if (CUTE_ARCH_LDSM_SM75_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + #define CUTE_ARCH_LDSM_SM75_ACTIVATED 1 +#endif + +namespace cute +{ + +struct SM75_U32x1_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U32x2_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U32x4_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U16x2_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U16x4_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U16x8_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : 
"r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +// +// Legacy LDSM interfaces that aren't very useful +// + +template +CUTE_HOST_DEVICE +void +copy_ldsm(uint128_t const* const smem_ptr, + T* rmem_ptr) +{ + uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM75_U32x1_LDSM_N::copy(smem_ptr[0], reg_ptr[0]); + } + else if (sizeof(T) == 8) { + SM75_U32x2_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); + } + else if (sizeof(T) == 16) { + SM75_U32x4_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +template +CUTE_HOST_DEVICE +void +copy_ldsm_trans(uint128_t const* const smem_ptr, + T* rmem_ptr) +{ + uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM75_U16x2_LDSM_T::copy(smem_ptr[0], reg_ptr[0]); + } + else if (sizeof(T) == 8) { + SM75_U16x4_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); + } + else if (sizeof(T) == 16) { + SM75_U16x8_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/copy_sm80.hpp b/server/punica_kernels/include/cutlass/cute/arch/copy_sm80.hpp new file mode 100644 index 00000000..43e3d0d7 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/copy_sm80.hpp @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +# define CUTE_ARCH_CP_ASYNC_SM80_ENABLED +#endif + +namespace cute +{ + +/// Copy via cp.async with caching at all levels +template +struct SM80_CP_ASYNC_CACHEALWAYS +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS))); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at global level +template +struct SM80_CP_ASYNC_CACHEGLOBAL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS))); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at all levels +template +struct SM80_CP_ASYNC_CACHEALWAYS_ZFILL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst, + bool pred) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + int src_size = pred ? sizeof(TS) : 0; + asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS)), + "r"(src_size)); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at global level +template +struct SM80_CP_ASYNC_CACHEGLOBAL_ZFILL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst, + bool pred) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + int src_size = pred ? 
sizeof(TS) : 0; + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS)), + "r"(src_size)); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +CUTE_HOST_DEVICE +void +cp_async_fence() +{ +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Blocks until all but N previous cp.async.commit_group operations have committed. +template +CUTE_HOST_DEVICE +void +cp_async_wait() +{ +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + if constexpr (N == 0) { + asm volatile("cp.async.wait_all;\n" ::); + } else { + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); + } +#endif +} + +template +CUTE_HOST_DEVICE +void +cp_async_wait(Int) +{ + return cp_async_wait(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/copy_sm90.hpp b/server/punica_kernels/include/cutlass/cute/arch/copy_sm90.hpp new file mode 100644 index 00000000..e5684ec4 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/copy_sm90.hpp @@ -0,0 +1,230 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +# define CUTE_ARCH_STSM_SM90_ENABLED +# define CUTE_ARCH_TMA_SM90_ENABLED +#endif + +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) && \ + ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3))) +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +#endif + +namespace cute +{ + +struct SM90_U32x1_STSM_N +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t & smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U32x2_STSM_N +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U32x4_STSM_N +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x2_STSM_T +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x4_STSM_T +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x8_STSM_T +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t 
const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +// +// Legacy STSM interfaces that aren't very useful +// + +template +CUTE_HOST_DEVICE +void +copy_stsm(T const* const rmem_ptr, + uint128_t* const smem_ptr) +{ + uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM90_U32x1_STSM_N::copy(reg_ptr[0], smem_ptr[0]); + } + else if (sizeof(T) == 8) { + SM90_U32x2_STSM_N::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); + } + else if (sizeof(T) == 16) { + SM90_U32x4_STSM_N::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +template +CUTE_HOST_DEVICE +void +copy_stsm_trans(T const* const rmem_ptr, + uint128_t* const smem_ptr) +{ + uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM90_U16x2_STSM_T::copy(reg_ptr[0], smem_ptr[0]); + } + else if (sizeof(T) == 8) { + SM90_U16x4_STSM_T::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); + } + else if (sizeof(T) == 16) { + SM90_U16x8_STSM_T::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cute/arch/copy_sm90_desc.hpp b/server/punica_kernels/include/cutlass/cute/arch/copy_sm90_desc.hpp new file mode 100644 index 00000000..856d4dd5 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/copy_sm90_desc.hpp @@ -0,0 +1,342 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#include +#endif + +#include + +#include +#include + +#include +#include +#include +#include + +namespace cute +{ + +////////////////////////////////////////////////////////////////////////////////////////////////////// +/// Barriers are 64-bit of user-managed information used in broadly two types syncronization patterns +/// 1) arrive/wait on threads (usage: cp.async and warp-specialized kernels) +/// 2) transaction-based (usage: TMA transaction where a CTA issues one transaction) +////////////////////////////////////////////////////////////////////////////////////////////////////// + +// Initialize barrier present in shared memory +CUTE_HOST_DEVICE +void +initialize_barrier(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem + int thread_count = 1) // Thread count expected to arrive/wait on this barrier +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); + asm volatile ("mbarrier.init.shared::cta.b64 [%0], %1;\n" + :: "r"(smem_int_ptr), + "r"(thread_count)); +#endif +} + +// Set the number of bytes transfered per transaction and perform an arrive operation as well +CUTE_HOST_DEVICE +void +set_barrier_transaction_bytes(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem + uint32_t bytes) // Number of bytes transfered by per TMA transaction +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); + asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" + :: "r"(smem_int_ptr), + "r"(bytes)); +#endif +} + +// Barrier wait +CUTE_HOST_DEVICE +void +wait_barrier(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem + int phase_bit) // Current phase bit the barrier waiting to flip +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); + asm volatile( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(smem_int_ptr), + "r"(phase_bit)); + +#endif +} + +// Barrier arrive +CUTE_HOST_DEVICE +void +arrive_barrier(uint64_t& smem_barrier) // 64 bits user-manged barrier in smem +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); + asm volatile( + "{\n" + ".reg .b64 state; \n" + "mbarrier.arrive.shared::cta.b64 state, [%0];\n" + "}\n" + :: "r"(smem_int_ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// TMA Descriptor and utilities +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace TMA { + +enum class SmemSwizzleBits : uint8_t { + DISABLE = 0, + B32 = 1, + B64 = 2, + B128 = 3, +}; + +#if 
(__CUDACC_VER_MAJOR__ >= 12) + +#if !defined(__CUDACC_RTC__) +/// @return The TMA descriptor datatype enum corresponding to T. +template +inline CUtensorMapDataType +to_CUtensorMapDataType() { + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else + { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); } +} + +inline CUtensorMapSwizzle +to_CUtensorMapSwizzle(SmemSwizzleBits const& t) { + switch (t) { + default: assert(false && "Unknown SmemSwizzleBits!"); + case SmemSwizzleBits::DISABLE: return CU_TENSOR_MAP_SWIZZLE_NONE; + case SmemSwizzleBits::B32: return CU_TENSOR_MAP_SWIZZLE_32B; + case SmemSwizzleBits::B64: return CU_TENSOR_MAP_SWIZZLE_64B; + case SmemSwizzleBits::B128: return CU_TENSOR_MAP_SWIZZLE_128B; + } +} +#endif // !defined(__CUDACC_RTC__) + +#endif // (__CUDACC_VER_MAJOR__ >= 12) + +} // end namespace TMA + +#if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + using TmaDescriptor = CUtensorMap; + using Im2ColTmaDescriptor = CUtensorMap; +#else + using TmaDescriptor = struct alignas(64) { char bytes[128]; }; + using Im2ColTmaDescriptor = struct alignas(64) { char bytes[128]; }; +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Initiates a TensorMap Prefetch +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +prefetch_tma_descriptor(TmaDescriptor const* desc_ptr) +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Prefetch TMA Descriptor using generic addressing (i.e. 
no specific state space: const or param) + asm volatile ( + "prefetch.tensormap [%0];" + : + : "l"(gmem_int_desc) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a TensorMap modification (by each field) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Replace tensor pointer directly in GMEM +CUTE_HOST_DEVICE +void +tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr, + void const* const new_tensor_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint64_t const new_desc_addr = reinterpret_cast(new_tensor_ptr); + asm volatile ( + "tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;" + :: "l"(gmem_int_desc), "l"(new_desc_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +// Replace tensor pointer by bringing the tensormap from GMEM into the shared memory +CUTE_HOST_DEVICE +void +tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc, + void const* const new_tensor_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + uint64_t const new_desc_addr = reinterpret_cast(new_tensor_ptr); + uint64_t const smem_int64_desc = 0; + asm volatile ( + "cvt.u64.u32 %0, %1;" + :: "l"(smem_int64_desc), "r"(smem_int_desc)); + asm volatile ( + "tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" + :: "l"(smem_int64_desc), "l"(new_desc_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +// Replace tensor dims and strides for GEMMs by bringing the tensormap from GMEM into the shared memory +CUTE_HOST_DEVICE +void +tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor & smem_desc, + cute::array const& prob_shape, + cute::array const& prob_stride) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + uint64_t const smem_int64_desc = 0; + asm volatile ( + "cvt.u64.u32 %0, %1;" + :: "l"(smem_int64_desc), "r"(smem_int_desc)); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[0])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[1])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[2])); + // Strides must be a multiple of 16. 
Also, stride for the intermost dimension is implicitly 1 + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a fused copy and fence operation (needed when modifying tensormap in shared memory) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_cp_fence_release(TmaDescriptor const* gmem_desc_ptr, TmaDescriptor& smem_desc) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + asm volatile ( + "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;" + :: "l"(gmem_int_desc), "r"(smem_int_desc)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a release fence operation (needed when modifying tensormap directly in GMEM) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_fence_release() +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + asm volatile ("fence.proxy.tensormap::generic.release.gpu;"); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a acquire fence operation +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_fence_acquire(TmaDescriptor const* desc_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "fence.proxy.tensormap::generic.acquire.gpu [%0], 128;" + : + : "l"(gmem_int_desc) + : "memory"); + asm volatile ( + "cvta.global.u64 %0, %0;" + : + : "l"(gmem_int_desc), "l"(gmem_int_desc) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +/////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/copy_sm90_tma.hpp b/server/punica_kernels/include/cutlass/cute/arch/copy_sm90_tma.hpp new file mode 100644 index 00000000..1136c433 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/copy_sm90_tma.hpp @@ -0,0 +1,1360 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +namespace cute +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3}], [%2];" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.1d.L2.global" + " [%0, {%1}];" + : + : "l"(gmem_int_desc), + "r"(crd0) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = 
cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4}], [%2];" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.2d.L2.global" + " [%0, {%1, %2}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5}], [%2];" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.3d.L2.global" + " [%0, {%1, %2, %3}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2];" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.4d.L2.global" + " [%0, {%1, %2, %3, %4}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct 
SM90_TMA_LOAD_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2];" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.5d.L2.global" + " [%0, {%1, %2, %3, %4, %5}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_1D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_2D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_3D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_5D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_1D::PREFETCH::copy(desc_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_2D::PREFETCH::copy(desc_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_3D::PREFETCH::copy(desc_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_4D::PREFETCH::copy(desc_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void 
const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_5D::PREFETCH::copy(desc_ptr, crd0, crd1, crd2, crd3, crd4); + } + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD im2col: Initiates a TMA copy, in im2col mode, from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5}], [%2], {%6};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.3d.L2.global.im2col" + " [%0, {%1, %2, %3}], {%4};" + : + : "l"(gmem_int_desc), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + // Copy from global to shared::cluster. 
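+    // The {coord_c, coord_w, coord_h, coord_n} tuple addresses the base element of the
+    // activation tensor described by the TMA descriptor, while {offset_w, offset_h} are
+    // the im2col offsets, i.e. the position inside the convolution filter window that the
+    // TMA unit adds to the W/H coordinates while gathering the tile (see the PTX
+    // cp.async.bulk.tensor .im2col documentation); the exact split between tensor
+    // coordinates and window offsets is fixed by the im2col descriptor built on the host.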
+ asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.4d.L2.global.im2col" + " [%0, {%1, %2, %3, %4}], {%5, %6};" + : + : "l"(gmem_int_desc), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.5d.L2.global.im2col" + " [%0, {%1, %2, %3, %4, %5}], {%6, %7, %8};" + : + : "l"(gmem_int_desc), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_3D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, 
uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_4D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_5D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_3D::PREFETCH::copy(desc_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_4D::PREFETCH::copy(desc_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_5D::PREFETCH::copy(desc_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_MULTICAST_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%4}], [%2], %3;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%4, %5}], [%2], %3;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + 
} +}; + +struct SM90_TMA_LOAD_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%4, %5, %6}], [%2], %3;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%4, %5, %6, %7}], [%2], %3;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_MULTICAST_1D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1, crd2); + } + 
CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST im2col: Initiates a TMA copy, in im2col mode, from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w), + "h"(multicast_mask) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + // Copy from global to shared::cluster. 
+ asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), + "h"(multicast_mask) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d), + "h"(multicast_mask) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + using PREFETCH = typename SM90_TMA_LOAD_IM2COL::PREFETCH; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_STORE : Initiates a TMA copy from shared memory to global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t 
gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_STORE_1D::copy(desc_ptr, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_STORE_2D::copy(desc_ptr, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t 
const& crd1, int32_t const& crd2) + { + return SM90_TMA_STORE_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_STORE_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_STORE_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_STORE im2col: Initiates a TMA copy, in im2col mode, from shared memory to global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_IM2COL_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_n) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_3D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_n); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, 
int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_4D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_n); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_5D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_d, coord_n); + } +}; + +// Fence for smem stores for subsequent TMA_STORE +CUTE_HOST_DEVICE static void +tma_store_fence() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile ("fence.proxy.async.shared::cta;"); +#elif defined(__CUDA_ARCH__) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +// Indicate arrival of warp issuing TMA_STORE +CUTE_HOST_DEVICE static void +tma_store_arrive() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile("cp.async.bulk.commit_group;"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +// Wait until at most Count committed TMA_STOREs are pending and all prior commits are complete +template +CUTE_HOST_DEVICE static void +tma_store_wait() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile( + "cp.async.bulk.wait_group.read %0;" + : + : "n"(Count) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_REDUCE_ADD : Initiates a TMA reduce-add from shared memory to global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_REDUCE_ADD_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.bulk_group [%0, {%2}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.bulk_group [%0, {%2, %3}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.3d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + 
"r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.4d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_REDUCE_ADD_1D::copy(desc_ptr, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_REDUCE_ADD_2D::copy(desc_ptr, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_REDUCE_ADD_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_REDUCE_ADD_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_REDUCE_ADD_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// BULK_COPY : Copy a bulk of memory between shared memory and global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_BULK_COPY_G2S +{ + CUTE_HOST_DEVICE static void + copy(void const* gmem_ptr, uint64_t* mbar_ptr, + void * smem_ptr, int32_t load_bytes) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm 
volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "r"(smem_int_ptr), "l"(gmem_ptr), "r"(load_bytes), "r"(smem_int_mbar) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* gmem_ptr, int32_t load_bytes) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile("cp.async.bulk.prefetch.L2.global [%0], %1;\n" + : + : "l"(gmem_ptr), "r"(load_bytes) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_BULK_COPY_S2G +{ + CUTE_HOST_DEVICE static void + copy(void const* smem_ptr, + void * gmem_ptr, int32_t store_bytes) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_BULK_COPY_AUTO {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/mma.hpp b/server/punica_kernels/include/cutlass/cute/arch/mma.hpp new file mode 100644 index 00000000..5bfda746 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/mma.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +// +// Direct FMA for any type +// + +template <class D, class A, class B, class C> +struct UniversalFMA +{ + using DRegisters = D[1]; + using ARegisters = A[1]; + using BRegisters = B[1]; + using CRegisters = C[1]; + + CUTE_HOST_DEVICE static constexpr void + fma(D & d, + A const& a, + B const& b, + C const& c) + { + // Forward to an ADL/cute free function for these types + using cute::fma; + fma(d, a, b, c); + } +}; + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/mma_sm61.hpp b/server/punica_kernels/include/cutlass/cute/arch/mma_sm61.hpp new file mode 100644 index 00000000..f7bcb7d1 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/mma_sm61.hpp @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) +# define CUTE_ARCH_MMA_SM61_ENABLED +#endif + +namespace cute +{ + +struct SM61_DP4A +{ + using DRegisters = int32_t[1]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = int32_t[1]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) + { +#if defined(CUTE_ARCH_MMA_SM61_ENABLED) + asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM61_DP4A without CUTE_ARCH_MMA_SM61_ENABLED"); +#endif + } +}; + +struct SM61_DP2A +{ + using DRegisters = int32_t[1]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = int32_t[1]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) + { +#if defined(CUTE_ARCH_MMA_SM61_ENABLED) + asm volatile("dp2a.s32.s32 %0, %1, %2, %3;" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM61_DP2A without CUTE_ARCH_MMA_SM61_ENABLED"); +#endif + } +}; + +} // namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/mma_sm70.hpp b/server/punica_kernels/include/cutlass/cute/arch/mma_sm70.hpp new file mode 100644 index 00000000..63d96cf5 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/mma_sm70.hpp @@ -0,0 +1,329 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) +# define CUTE_ARCH_MMA_SM70_SUPPORTED +# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) +# define CUTE_ARCH_MMA_SM70_ENABLED +# endif +#endif + +namespace cute +{ + +// +// SM70 MMA 884 F16F16F16 +// + +struct SM70_8x8x4_F16F16F16F16_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_TN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_NT +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_NT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_NN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_NN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_TT +{ + using 
DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_TT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// SM70 MMA 884 F16F16F32 +// + +struct SM70_8x8x4_F32F16F16F32_TN +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_TN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_NT +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_NT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_NN +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + 
using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_NN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_TT +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_TT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/mma_sm75.hpp b/server/punica_kernels/include/cutlass/cute/arch/mma_sm75.hpp new file mode 100644 index 00000000..c33f7b39 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/mma_sm75.hpp @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) +# define CUTE_ARCH_MMA_SM75_SUPPORTED +# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) +# define CUTE_ARCH_MMA_SM75_ENABLED +# endif +#endif + +namespace cute +{ + +// +// SM75 MMA 1688 F16F16F32 +// + +struct SM75_16x8x8_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_SM75_ENABLED) + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM75_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// SM75 MMA 8816 S8S8S32 +// + +struct SM75_8x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM75_ENABLED) + asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32" + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM75_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/mma_sm80.hpp b/server/punica_kernels/include/cutlass/cute/arch/mma_sm80.hpp new file mode 100644 index 00000000..8c684b70 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/mma_sm80.hpp 
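Each of these wrappers is a thin PTX shim: the warp as a whole performs one mma.sync, with every thread contributing the register slots listed in ARegisters, BRegisters, and CRegisters. A minimal sketch of calling the m16n8k8 f16 wrapper defined above, assuming the fragments already sit in the per-thread layout the PTX instruction requires; only the packing of half2 pairs into the 32-bit operand registers is shown.

#include <cstdint>
#include <cuda_fp16.h>
#include <cute/arch/mma_sm75.hpp>   // the header added above

// Sketch only; every thread of the warp must make this call with its own fragments.
__device__ void mma_16x8x8_sketch(__half2 const a_frag[2], __half2 const b_frag,
                                  float   const c_frag[4], float d_frag[4])
{
  // Two f16 values per 32-bit register, exactly as the PTX operands expect.
  uint32_t a0 = reinterpret_cast<uint32_t const&>(a_frag[0]);
  uint32_t a1 = reinterpret_cast<uint32_t const&>(a_frag[1]);
  uint32_t b0 = reinterpret_cast<uint32_t const&>(b_frag);

  cute::SM75_16x8x8_F32F16F16F32_TN::fma(d_frag[0], d_frag[1], d_frag[2], d_frag[3],
                                         a0, a1, b0,
                                         c_frag[0], c_frag[1], c_frag[2], c_frag[3]);
}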
@@ -0,0 +1,2144 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +# define CUTE_ARCH_MMA_SM80_ENABLED + +#if (__CUDA_ARCH__ <= 900) +#define CUTE_ARCH_MMA_B1_AND_SM80_ENABLED +#endif + +#if (__CUDA_ARCH__ <= 890) +#define CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED +#endif + +#endif + + + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F16F16F16F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F16F16F16F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& 
b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32BF16BF16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + 
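The struct names encode the tile shape and the D/A/B/C element types (SM80_MxNxK_DABC_TN, where the _TN suffix mirrors the row.col layout qualifier in the PTX string: A taken as row-major, B as column-major), and the nested *Registers typedefs record how many 32-bit registers per thread each operand occupies; note how the A and B counts double going from k8 to k16. A small compile-time check of that convention, assuming this header is on the include path:

#include <cstdint>
#include <type_traits>
#include <cute/arch/mma_sm80.hpp>   // the header added in this diff

// Sketch only: the k16 f16 MMA takes four packed A registers and two packed B
// registers per thread, accumulating into four f32 values.
static_assert(std::is_same<cute::SM80_16x8x16_F32F16F16F32_TN::ARegisters, uint32_t[4]>::value,
              "four 32-bit A registers per thread");
static_assert(std::is_same<cute::SM80_16x8x16_F32F16F16F32_TN::BRegisters, uint32_t[2]>::value,
              "two 32-bit B registers per thread");
static_assert(std::is_same<cute::SM80_16x8x16_F32F16F16F32_TN::CRegisters, float[4]>::value,
              "four f32 accumulators per thread");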
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F32BF16BF16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct SM80_16x8x4_F32TF32TF32F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x4_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32TF32TF32F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x4 TN +struct SM80_8x8x4_F64F64F64F64_TN +{ + using DRegisters = double[2]; + using ARegisters = double[1]; + using BRegisters = double[1]; + using CRegisters = double[2]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, + double const& a0, + double const& b0, + double const& c0, double const& c1) + { +#if 
defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=d"(d0), "=d"(d1) + : "d"(a0), + "d"(b0), + "d"(c0), "d"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +// MMA 8x8x4 TN with Planar Complex multiplication +struct SM80_8x8x4_C64C64C64C64_TN +{ + using DRegisters = complex[2]; + using ARegisters = complex[1]; + using BRegisters = complex[1]; + using CRegisters = complex[2]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex const& a0, + complex const& b0, + complex const& c0, complex const& c1) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + + // d.real() = a.real() * b.real() + c.real(); + SM80_8x8x4_F64F64F64F64_TN::fma( + rd0, rd1, + a0.real(), + b0.real(), + c0.real(), c1.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + SM80_8x8x4_F64F64F64F64_TN::fma( + id0, id1, + a0.imag(), + b0.real(), + c0.imag(), c1.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + SM80_8x8x4_F64F64F64F64_TN::fma( + rd0, rd1, + -a0.imag(), + b0.imag(), + d0.real(), d1.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + SM80_8x8x4_F64F64F64F64_TN::fma( + id0, id1, + a0.real(), + b0.imag(), + d0.imag(), d1.imag()); + } +}; + +// MMA 8x8x4 TN with Gaussian Complex multiplication: +// (a + bi)*(c + di) +// yields +// t0 += a*c +// t1 += b*d +// t2 += (a+b)*(c+d) +// then +// re = t0 - t1 +// im = t2 - t0 - t1 +struct SM80_8x8x4_GC64C64C64GC64_TN +{ + struct GaussComplex { + double t0, t1, t2; + + CUTE_HOST_DEVICE //constexpr + operator complex() const { return complex(t0 - t1, t2 - t0 - t1); } + + CUTE_HOST_DEVICE friend //constexpr + complex operator*(GaussComplex const& a, complex const& b) { return static_cast>(a) * b; } + CUTE_HOST_DEVICE friend //constexpr + complex operator*(complex const& a, GaussComplex const& b) { return b * a; } + + CUTE_HOST_DEVICE friend //constexpr + complex operator+(GaussComplex const& a, complex const& b) { return static_cast>(a) + b; } + CUTE_HOST_DEVICE friend //constexpr + complex operator+(complex const& a, GaussComplex const& b) { return b + a; } + }; + + using DRegisters = GaussComplex[2]; + using ARegisters = complex[1]; + using BRegisters = complex[1]; + using CRegisters = GaussComplex[2]; + + CUTE_HOST_DEVICE static void + fma(GaussComplex & d0, GaussComplex & d1, + complex const& a0, + complex const& b0, + GaussComplex const& c0, GaussComplex const& c1) + { + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t0, d1.t0, + a0.real(), + b0.real(), + c0.t0, c1.t0); + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t1, d1.t1, + a0.imag(), + b0.imag(), + c0.t1, c1.t1); + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t2, d1.t2, + a0.real() + a0.imag(), + b0.real() + b0.imag(), + c0.t2, c1.t2); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if 
defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, 
uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8U8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + 
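+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Illustrative usage sketch (not part of the upstream CUTLASS header; guarded out of the build).
+// It shows the calling convention shared by the integer MMA atoms above: each of the 32 threads in
+// a warp packs four signed 8-bit values of its fragment into one 32-bit operand register, and the
+// warp cooperatively issues a single mma.sync through SM80_16x8x16_S32S8S8S32_TN::fma. The exact
+// fragment-to-thread mapping follows the PTX ISA m16n8k16 layout and is normally handled by
+// cute::MMA_Atom; the helper names below are hypothetical and exist only for this sketch.
+#if 0  // illustration only
+__device__ inline uint32_t pack_s8x4(int8_t x0, int8_t x1, int8_t x2, int8_t x3) {
+  // Pack four int8 lanes into one uint32_t operand register, x0 in the least significant byte.
+  return (uint32_t(uint8_t(x0))      ) | (uint32_t(uint8_t(x1)) <<  8) |
+         (uint32_t(uint8_t(x2)) << 16) | (uint32_t(uint8_t(x3)) << 24);
+}
+
+__device__ inline void example_s8_mma(int8_t const (&a)[8], int8_t const (&b)[4], uint32_t (&acc)[4]) {
+  // Per-thread fragments for one m16n8k16 tile: 2 packed A registers, 1 packed B register,
+  // and 4 x 32-bit signed accumulators updated in place (D = A*B + C).
+  uint32_t a0 = pack_s8x4(a[0], a[1], a[2], a[3]);
+  uint32_t a1 = pack_s8x4(a[4], a[5], a[6], a[7]);
+  uint32_t b0 = pack_s8x4(b[0], b[1], b[2], b[3]);
+  SM80_16x8x16_S32S8S8S32_TN::fma(acc[0], acc[1], acc[2], acc[3],
+                                  a0, a1,
+                                  b0,
+                                  acc[0], acc[1], acc[2], acc[3]);
+}
+#endif
+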
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, 
+ uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters 
= uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8U8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), 
+ "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t 
const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4S4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = 
uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite " + 
"{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4U4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t 
const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4S4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + 
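+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Illustrative note (not part of the upstream CUTLASS header; guarded out of the build). The s4/u4
+// atoms above use the same register convention as the 8-bit ones, except that each 32-bit operand
+// register now carries eight 4-bit values, which is why the K extent doubles (k32/k64 here versus
+// k16/k32 for 8-bit inputs). The helper below is a hypothetical sketch of packing eight signed
+// 4-bit values (range -8..7) into one uint32_t, least significant nibble first.
+#if 0  // illustration only
+__host__ __device__ inline uint32_t pack_s4x8(int const (&v)[8]) {
+  uint32_t packed = 0;
+  for (int i = 0; i < 8; ++i) {
+    packed |= (uint32_t(v[i]) & 0xFu) << (4 * i);   // keep the low two's-complement nibble
+  }
+  return packed;
+}
+#endif
+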
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 " + "{%0, 
%1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4U4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t 
const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), 
"r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x128 TN +struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x128 TN +struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x256 TN +struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/mma_sm90.hpp b/server/punica_kernels/include/cutlass/cute/arch/mma_sm90.hpp new file 
mode 100644 index 00000000..10bed48a --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/mma_sm90.hpp @@ -0,0 +1,1402 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include + +#include + +// Config +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +# define CUTE_ARCH_MMA_SM90_ENABLED +# define CUTE_ARCH_MMA_F64_SM90_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct SM90_16x8x4_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[2]; + using BRegisters = double[1]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, + double const& b0, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), + "d"(b0), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM90_16x8x8_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[4]; + using BRegisters = double[2]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, double const& a2, double const& a3, + double const& b0, double const& b1, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), "d"(a2), "d"(a3), + "d"(b0), "d"(b1), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM90_16x8x16_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[8]; + using BRegisters = double[4]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, double const& a2, double const& a3, + double const& a4, double const& a5, double const& a6, double const& a7, + double const& b0, double const& b1, double const& b2, double const& b3, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7, %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16, %17, %18, %19};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), "d"(a2), "d"(a3), + "d"(a4), "d"(a5), "d"(a6), "d"(a7), + "d"(b0), "d"(b1), "d"(b2), "d"(b3), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct SM90_16x8x4_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[2]; + using BRegisters = complex[1]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& b0, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + SM90_16x8x4_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), + b0.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + SM90_16x8x4_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), + b0.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + SM90_16x8x4_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), + b0.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + SM90_16x8x4_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), + b0.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM90_16x8x8_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[4]; + using BRegisters = complex[2]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& a2, complex const& a3, + complex const& b0, complex const& b1, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + SM90_16x8x8_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), a2.real(), a3.real(), + b0.real(), b1.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + SM90_16x8x8_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), a2.imag(), a3.imag(), + b0.real(), b1.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + SM90_16x8x8_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), + b0.imag(), b1.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + 
SM90_16x8x8_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), a2.real(), a3.real(), + b0.imag(), b1.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM90_16x8x16_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[8]; + using BRegisters = complex[4]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& a2, complex const& a3, + complex const& a4, complex const& a5, + complex const& a6, complex const& a7, + complex const& b0, complex const& b1, + complex const& b2, complex const& b3, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + SM90_16x8x16_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), a2.real(), a3.real(), + a4.real(), a5.real(), a6.real(), a7.real(), + b0.real(), b1.real(), b2.real(), b3.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + SM90_16x8x16_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), a2.imag(), a3.imag(), + a4.imag(), a5.imag(), a6.imag(), a7.imag(), + b0.real(), b1.real(), b2.real(), b3.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + SM90_16x8x16_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), + -a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(), + b0.imag(), b1.imag(), b2.imag(), b3.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + SM90_16x8x16_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), a2.real(), a3.real(), + a4.real(), a5.real(), a6.real(), a7.real(), + b0.imag(), b1.imag(), b2.imag(), b3.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { +namespace GMMA { + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. 
GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +ss_op_selector() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // FP16 accumulator + if constexpr (is_same_v) { + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + // Dispatch against the Tile N mode size + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x16_F16F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F16E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F16E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F16E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F16E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F16E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F16E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F16E4M3E5M2_SS_TN{}; + } + else if constexpr 
(Tile_N % 16 == 0) { + return SM90_64x16x32_F16E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F16E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F16E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F16E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // FP32 accumulator + else if constexpr (is_same_v) { + + // FP16 inputs + if constexpr (is_same_v) { + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 16 == 0) 
{ + return SM90_64x16x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x16_F32F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // BF16 inputs + else if constexpr (is_same_v) { + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x16_F32BF16BF16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // TF32 inputs + else if constexpr (is_same_v) { + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x8_F32TF32TF32_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E4M3E4M3_SS_TN{}; + } + else { + 
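+      // Fallback: none of the supported N extents (8 through 256) divides Tile_N, so fail at
+      // compile time.  The chain above always returns the atom with the largest N that divides
+      // Tile_N, minimizing the number of wgmma instructions needed to cover the tile.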
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E5M2E4M3_SS_TN{}; + } + else { + 
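+      // Fallback for unsupported Tile_N.  As a worked example of the dispatch above: Tile_N = 160
+      // fails the 256/192/128/96/64 checks but satisfies 160 % 32 == 0, so the selector returns
+      // the 64x32x32 atom and the tile is covered by five MMAs along N.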
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + // ElementA == int8_t && ElementB == int8_t + if constexpr (is_same_v && is_same_v) { + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32S8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // ElementA == int8_t && ElementB == uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32S8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // ElementA == uint8_t && ElementB == int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32U8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // ElementA == uint8_t && ElementB == uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32U8U8_SS_TN{}; + } + 
else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32U8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +rs_op_selector() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // FP16 accumulator + if constexpr (is_same_v) { + static_assert(is_same_v, "Element types for AB must be half if ElementC is half."); + static_assert(is_same_v, "Element types for AB must be half if ElementC is half."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + // Dispatch against the Tile N mode size + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x16_F16F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP32 accumulator + else if constexpr (is_same_v) { + + // FP16 inputs + if constexpr (is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x16_F32F16F16_RS{}; + } + else 
if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x16_F32F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // BF16 inputs + else if constexpr (is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x16_F32BF16BF16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // TF32 inputs + else if constexpr (is_same_v) { + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x8_F32TF32TF32_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr 
(is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for 
request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + // ElementA == int8_t && ElementB == int8_t + if constexpr (is_same_v && is_same_v) { + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32S8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // ElementA == int8_t && ElementB == uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32S8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // ElementA == uint8_t && ElementB == int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32U8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // ElementA == uint8_t && ElementB == uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return 
SM90_64x96x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_S32U8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} +} // end namespace GMMA +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cute/arch/mma_sm90_desc.hpp b/server/punica_kernels/include/cutlass/cute/arch/mma_sm90_desc.hpp new file mode 100644 index 00000000..a6cb1943 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/mma_sm90_desc.hpp @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA Descriptor and utilities + +// GMMA enums and utilities +namespace GMMA +{ + +enum class LayoutType : uint8_t { + INTERLEAVE = 0, + B128 = 1, + B64 = 2, + B32 = 3, +}; + +CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) { + switch (t) { + case LayoutType::INTERLEAVE: return "INTERLEAVE"; + case LayoutType::B128: return "B128"; + case LayoutType::B64: return "B64"; + case LayoutType::B32: return "B32"; + } + return nullptr; +} + +#if !defined(__CUDACC_RTC__) +// Output operator for all enums in this namespace +CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { + char const* s = to_string(t); + if (s) { + std::operator<<(os, s); // Explicit call to avoid ambiguity + } else { + os.setstate(std::ios_base::failbit); + } + return os; +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace GMMA + +union GmmaDescriptor +{ + + CUTE_HOST_DEVICE constexpr + GmmaDescriptor() noexcept : desc_(0) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(uint64_t desc) noexcept : desc_(desc) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(GmmaDescriptor const& t) noexcept : desc_(t.desc_) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(GmmaDescriptor && t) noexcept : desc_(t.desc_) {} + + CUTE_HOST_DEVICE constexpr + GmmaDescriptor& operator=(GmmaDescriptor const& t) noexcept { + desc_ = t.desc_; + return *this; + } + + CUTE_HOST_DEVICE constexpr + GmmaDescriptor& operator=(GmmaDescriptor && t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint64_t desc_; + uint32_t reg32_[2]; + uint16_t reg16_[4]; + + // Bitfield implementation avoids the need for shifts in assignment + struct { + // start_address, bit [0,14), 4LSB not included + uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // leading dimension byte offset, bit [16,30), 4LSB not included + // For N: This is the stride from the first col to the second col of the 8x2 brick in INTERLEAVED + // Unused for all SWIZZLE_* layouts (and assumed to be 1) + // For T: This is the stride from the first 8 rows to the next 8 rows. + uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // stride dimension byte offset, bit [32,46), 4LSB not included + // For N: This is the stride from the first 8 rows to the next 8 rows. + // For T: This is the stride fro mthe first 8 cols to the next 8 cols. 
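+    // Note: all three 14-bit address/offset fields are stored in units of 16 bytes (the dropped
+    // 4 LSBs), so a byte offset of 256 is encoded as 16.  Which fields the hardware actually reads
+    // depends on layout_type_ below.  A rough, illustrative sketch for a 128B-swizzled operand
+    // (smem_addr and stride_bytes are placeholders):
+    //   desc.bitfield.layout_type_        = 1;                  // LayoutType::B128
+    //   desc.bitfield.start_address_      = smem_addr    >> 4;  // shared memory address / 16
+    //   desc.bitfield.stride_byte_offset_ = stride_bytes >> 4;
+    // In practice these descriptors are built by CuTe's GMMA traits rather than filled by hand.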
+ uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // base_offset, bit [49,52) + // Valid only for SWIZZLE_128B and SWIZZLE_64B + uint8_t : 1, base_offset_ : 3, : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused + // layout type, bit [62,64) + // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 + uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8) + } bitfield; + + // Decay to a uint64_t + CUTE_HOST_DEVICE constexpr + operator uint64_t() const noexcept { return desc_; } + + // Printer + CUTE_HOST_DEVICE friend void print(GmmaDescriptor const& t) + { + #if !defined(__CUDACC_RTC__) + printf("GmmaDescriptor: 0x%016llx\n", static_cast(t.desc_)); + printf(" start_addr : 0x%04x\n", t.bitfield.start_address_); + printf(" leading_off: 0x%04x (%d)\n", t.bitfield.leading_byte_offset_, t.bitfield.leading_byte_offset_); + printf(" stride_off : 0x%04x (%d)\n", t.bitfield.stride_byte_offset_, t.bitfield.stride_byte_offset_); + printf(" base_offset: 0x%01x\n", t.bitfield.base_offset_); + printf(" layout_type: 0x%01x (%s)\n", t.bitfield.layout_type_, to_string(static_cast(t.bitfield.layout_type_))); + #endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cute/arch/mma_sm90_gmma.hpp b/server/punica_kernels/include/cutlass/cute/arch/mma_sm90_gmma.hpp new file mode 100644 index 00000000..bdf0d70e --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/mma_sm90_gmma.hpp @@ -0,0 +1,20639 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Warpgroup sync primitives + +CUTE_HOST_DEVICE +void +warpgroup_arrive() +{ +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.fence without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif +} + +template +CUTE_HOST_DEVICE +void +warpgroup_wait() +{ + static_assert(N >= 0 && N <= 7, "WGMMA wait: N must be in range [0, 7]"); +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif +} + +// Marks the commit point for one or more sized batch of warpgroup MMAs. +CUTE_HOST_DEVICE +void +warpgroup_commit_batch() +{ +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.commit_group without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif +} + +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(uint32_t& reg) { + // MSVC emits a build error for 'asm volatile' + // even if it only occurs in a __device__ function. + // This prevents the error. +#if defined(__CUDA_ARCH__) + asm volatile("" : "+r"(reg) :: "memory"); +#endif +} + +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(float& reg) { +#if defined(__CUDA_ARCH__) + asm volatile("" : "+f"(reg) :: "memory"); +#endif +} + +namespace GMMA { + +enum class Major { + K = 0, + MN = 1 +}; + +enum class ScaleOut { + Zero = 0, + One = 1 +}; + +enum class ScaleIn { + Neg = -1, + One = 1 +}; + +} // namespace GMMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F16+=F16*F16 +template < 
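+  // The parameters below follow the convention stated above for every GMMA atom: tnspA/tnspB give
+  // the major-ness of A and B (GMMA::Major::K or ::MN), and scaleA/scaleB (GMMA::ScaleIn::One or
+  // ::Neg) optionally negate an input, so the atom computes C = (scaleA*A) * (scaleB*B) + (scaleD*C).
+  // This _RS variant reads A from registers instead of shared memory, which is why the
+  // static_assert in the struct body requires a K-major A layout.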
+ GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), 
"r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + 
uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
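For reference, a minimal sketch of how one of the descriptor-sourced (SS) atoms above is typically driven from device code. The include path, the placement of the structs in namespace cute, and the bare descriptor arguments are assumptions for illustration; real kernels build desc_a/desc_b from shared-memory layouts and wrap the call in the wgmma commit/wait sequence.

// Illustrative sketch only. Assumptions: the header is reachable as
// cute/arch/mma_sm90_gmma.hpp and the structs live in namespace cute, as in
// upstream CUTLASS; GMMA descriptor construction is not shown here.
#include <cstdint>
#include <cute/arch/mma_sm90_gmma.hpp>

__device__ void wgmma_64x32x16_f16_example(uint64_t desc_a,    // GMMA descriptor for A tile in shared memory
                                           uint64_t desc_b,    // GMMA descriptor for B tile in shared memory
                                           uint32_t (&acc)[8]) // 8 packed f16x2 accumulator registers
{
  // Both operands K-major; scaleA/scaleB keep their ScaleIn::One defaults.
  using MMA = cute::SM90_64x32x16_F16F16F16_SS<cute::GMMA::Major::K,
                                               cute::GMMA::Major::K>;
  // scale_D defaults to ScaleOut::One, so the atom accumulates (D += A*B).
  // Only valid when compiled for sm_90a; the async wgmma still has to be
  // committed and waited on (wgmma.commit_group / wgmma.wait_group) before
  // the accumulator registers are read back.
  MMA::fma(desc_a, desc_b,
           acc[0], acc[1], acc[2], acc[3],
           acc[4], acc[5], acc[6], acc[7]);
}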
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" 
+ "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), 
"n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t 
& d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : 
"+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K 
major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t 
const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + 
uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, 
float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == 
GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, 
%123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + 
float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + 
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & 
d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), 
"+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), 
"+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, 
float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, 
%11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + 
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), 
"+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), 
"+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " 
%80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + 
float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), 
"+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) 
+ { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, 
float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x8_F32TF32TF32_SS_TN +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), 
"+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), 
"+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), 
"+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, 
float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), 
"+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct SM90_64x8x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct SM90_64x16x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t 
& d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct SM90_64x32x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + 
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct SM90_64x64x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & 
d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct SM90_64x96x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), 
"+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct SM90_64x128x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & 
d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, 
uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct SM90_64x192x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & 
d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & 
d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct SM90_64x256x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & 
d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + 
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, 
uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct SM90_64x8x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, 
uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct SM90_64x16x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct SM90_64x32x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_RS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct SM90_64x64x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct SM90_64x96x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, 
uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct SM90_64x128x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, 
uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, 
%6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct SM90_64x192x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 
p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, 
uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct SM90_64x256x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + 
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), 
"+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, 
uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct SM90_64x8x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + 
uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct SM90_64x16x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : 
"+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct SM90_64x32x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct SM90_64x64x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + 
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN 
S32+=S8*U8 +struct SM90_64x96x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct SM90_64x128x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), 
"+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct SM90_64x192x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), 
"+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " 
%96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct SM90_64x256x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t 
& d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), 
"+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct SM90_64x8x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), 
"+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct SM90_64x16x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct SM90_64x32x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct SM90_64x64x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + 
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct SM90_64x96x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, 
uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), 
"+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct SM90_64x128x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), 
"+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN 
S32+=S8*U8 +struct SM90_64x192x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), 
"+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + 
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct SM90_64x256x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + 
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + 
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct SM90_64x8x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct SM90_64x16x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct SM90_64x32x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + 
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct SM90_64x64x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct SM90_64x96x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, 
%25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 
+struct SM90_64x128x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t 
& d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct SM90_64x192x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + 
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, 
uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
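Every wrapper in this family follows the same shape: `ARegisters`/`BRegisters`/`CRegisters` describe the operand storage (64-bit shared-memory descriptors for the `SS` operands, packed 32-bit registers for `RS` A-operands and for the s32 accumulators), and `fma` emits a single predicated `wgmma.mma_async` instruction, with `scale_D` selecting whether the existing accumulator contents are kept (`GMMA::ScaleOut::One`) or discarded. The sketch below is a hypothetical illustration, not part of the header: it calls one of the smaller descriptor-operand atoms shown above directly, assuming `desc_a` and `desc_b` are valid GMMA shared-memory descriptors built by the caller, that namespace qualification is handled by the including code, and that the surrounding kernel issues the fence/commit/wait sequencing `wgmma.mma_async` requires. In normal use these atoms are reached through CuTe's MMA traits rather than invoked by hand.

// Hypothetical helper (illustration only): accumulate one 64x8x32 U8*S8 tile
// product into the four s32 accumulator registers owned by this thread of the
// warpgroup, using the SM90_64x8x32_S32U8S8_SS_TN atom defined above.
__device__ void accumulate_u8s8_64x8x32(uint64_t desc_a, uint64_t desc_b,
                                        uint32_t (&acc)[4])
{
  // ScaleOut::One keeps the current accumulator values (D = A*B + D);
  // ScaleOut::Zero would overwrite them with the fresh tile product.
  SM90_64x8x32_S32U8S8_SS_TN::fma(desc_a, desc_b,
                                  acc[0], acc[1], acc[2], acc[3],
                                  GMMA::ScaleOut::One);
}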
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct SM90_64x256x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, 
%118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & 
d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), 
"+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct SM90_64x8x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct SM90_64x16x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct SM90_64x32x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, 
uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct SM90_64x64x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, 
uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct SM90_64x96x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), 
"+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct SM90_64x128x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & 
d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & 
d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct SM90_64x192x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, 
uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t 
& d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 
64x256x32 TN S32+=U8*S8 +struct SM90_64x256x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, 
%125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, 
uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), 
"+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct SM90_64x8x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct SM90_64x16x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct SM90_64x32x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct SM90_64x64x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + 
" %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct SM90_64x96x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + 
uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct SM90_64x128x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, 
uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, 
%50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct SM90_64x192x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, 
%51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, 
uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct SM90_64x256x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & 
d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), 
"+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct SM90_64x8x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
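A minimal usage sketch of the register-sourced (RS) atom defined just above, assuming the vendored header is reachable as cute/arch/mma_sm90_gmma.hpp and that the enclosing kernel already brackets the call with the required wgmma fence/commit_group/wait_group sequence (omitted here); the fragment and descriptor names are illustrative only:

#include <cute/arch/mma_sm90_gmma.hpp>   // assumed include path for this vendored header

// Each thread of the 128-thread warpgroup contributes 4 registers of A
// (32 packed u8 values) and owns 4 s32 accumulators of the 64x8 output tile.
__device__ void u8_mma_64x8x32(uint32_t const (&a_frag)[4],  // A fragment, K-major u8, 4 values per reg
                               uint64_t        desc_b,        // GMMA shared-memory descriptor for B
                               uint32_t       (&acc)[4],      // per-thread S32 accumulator fragment
                               bool            first_k_block)
{
  using cute::GMMA::ScaleOut;
  // ScaleOut::Zero overwrites the accumulator (C = A*B); ScaleOut::One accumulates (C += A*B).
  cute::SM90_64x8x32_S32U8U8_RS_TN::fma(
      a_frag[0], a_frag[1], a_frag[2], a_frag[3],
      desc_b,
      acc[0], acc[1], acc[2], acc[3],
      first_k_block ? ScaleOut::Zero : ScaleOut::One);
}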
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct SM90_64x16x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
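The _SATURATE variants differ from their plain counterparts only by the .satfinite qualifier on the wgmma instruction, which clamps the s32 result to its representable range rather than letting it wrap. A host-side reference of the per-element semantics, for orientation (illustrative code, not taken from this header):

#include <algorithm>
#include <cstdint>

// Reference for one output element: accumulate k unsigned 8-bit products into an s32 value.
int32_t u8u8_s32_accumulate(int32_t c, const uint8_t* a, const uint8_t* b, int k, bool satfinite)
{
  int64_t acc = c;
  for (int i = 0; i < k; ++i)
    acc += int64_t(a[i]) * int64_t(b[i]);
  if (satfinite)                                   // .satfinite: clamp to [INT32_MIN, INT32_MAX]
    acc = std::min<int64_t>(std::max<int64_t>(acc, INT32_MIN), INT32_MAX);
  return static_cast<int32_t>(acc);                // otherwise the value effectively wraps modulo 2^32
}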
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct SM90_64x32x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct SM90_64x64x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t 
& d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
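Every 8-bit integer atom in this family follows the same naming and register-budget convention: SM90_<M>x<N>x<K>_S32U8U8_{SS,RS}_TN[_SATURATE], where SS sources both A and B from shared memory via 64-bit descriptors, RS sources A from four 32-bit registers per thread, and TN (both operands K-major) is the only layout these 8-bit atoms provide. The CRegisters sizes fall out of the tile shape; a small compile-time sketch (helper names are illustrative):

// Per-thread accumulator registers = M*N s32 results spread over the 128-thread warpgroup.
constexpr int kWarpgroupThreads = 128;
constexpr int acc_regs(int m, int n) { return m * n / kWarpgroupThreads; }
static_assert(acc_regs(64,   8) ==   4, "matches CRegisters of the 64x8 atoms");
static_assert(acc_regs(64, 128) ==  64, "matches CRegisters of the 64x128 atoms");
static_assert(acc_regs(64, 256) == 128, "matches CRegisters of the 64x256 atoms");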
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct SM90_64x96x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t 
& d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct SM90_64x128x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, 
%36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct SM90_64x192x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, 
%90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct SM90_64x256x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + 
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + 
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, 
%6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), 
"+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, 
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), 
"r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), 
"+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
SM90_64x96x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" 
+ "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, 
float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" 
+ "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), 
"+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + 
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + 
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), 
"+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, 
%45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, 
+ float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), 
"+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : 
"l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E5M2 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
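+// Usage sketch (illustrative only, not exercised by this patch): each atom's
+// fma() issues one warpgroup-wide wgmma.mma_async instruction; the _SS_
+// variants take 64-bit shared-memory matrix descriptors for both A and B,
+// while the _RS_ variants take the A fragment from registers. Assuming
+// desc_a/desc_b are valid GMMA shared-memory descriptors built by the
+// surrounding CuTe machinery, the smallest F32 E4M3*E5M2 tile defined above
+// could be driven directly like this (in practice these atoms are wrapped in
+// cute::MMA_Atom / cute::TiledMMA, which handle descriptor construction and
+// accumulator partitioning rather than calling fma() by hand):
+//
+//   float acc[4] = {0.f, 0.f, 0.f, 0.f};          // per-thread accumulator fragment
+//   cute::warpgroup_arrive();                      // fence before issuing wgmma
+//   cute::SM90_64x8x32_F32E4M3E5M2_SS_TN<>::fma(
+//       desc_a, desc_b,
+//       acc[0], acc[1], acc[2], acc[3],
+//       cute::GMMA::ScaleOut::One);                // One => D = A*B + D
+//   cute::warpgroup_commit_batch();                // commit the async MMA batch
+//   cute::warpgroup_wait<0>();                     // wait until results land in acc
+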
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), 
+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), 
"+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), 
"+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90_64x128x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, 
uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & 
d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, 
float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, 
uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & 
d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float 
& d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + 
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " 
%88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
SM90_64x8x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
SM90_64x16x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E4M3 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), 
"+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E4M3 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + 
uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90_64x64x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, 
float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), 
"+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, 
float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), 
"+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 
64x192x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), 
"+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), 
"+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), 
"+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + 
"l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, 
%124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, 
float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & 
d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t 
const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
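+// Illustrative sketch only; not part of the upstream CUTLASS header. It shows how the SS atom
+// defined directly above (SM90_64x96x32_F16E5M2E5M2_SS_TN) would be invoked by a full warpgroup.
+// The shared-memory matrix descriptors desc_a / desc_b are assumed to have been built elsewhere
+// (e.g. from a GMMA shared-memory layout), and acc is this thread's slice of the accumulator:
+// 64*96 f16 results / 128 threads / 2 halves per register = 24 packed f16x2 registers.
+// wgmma fencing and completion (fence / commit_group / wait_group) are the caller's responsibility
+// and are omitted here.
+//
+//   uint64_t desc_a = ...;   // descriptor for the A tile in shared memory (assumed)
+//   uint64_t desc_b = ...;   // descriptor for the B tile in shared memory (assumed)
+//   uint32_t acc[24] = {};   // per-thread accumulator registers, zero-initialized
+//   SM90_64x96x32_F16E5M2E5M2_SS_TN<>::fma(desc_a, desc_b,
+//       acc[ 0], acc[ 1], acc[ 2], acc[ 3], acc[ 4], acc[ 5], acc[ 6], acc[ 7],
+//       acc[ 8], acc[ 9], acc[10], acc[11], acc[12], acc[13], acc[14], acc[15],
+//       acc[16], acc[17], acc[18], acc[19], acc[20], acc[21], acc[22], acc[23],
+//       GMMA::ScaleOut::One);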
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, 
%43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
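+// Illustrative sketch only; not part of the upstream CUTLASS header. In practice these ops are
+// rarely called directly: they are wrapped in a cute::MMA_Atom and tiled over a warpgroup, with the
+// matching MMA_Traits specializations expected to live in the companion mma_traits_sm90_gmma.hpp
+// header. Assuming that header is included, a minimal construction might look like:
+//
+//   using Op       = cute::SM90_64x96x32_F32E5M2E5M2_SS_TN<>;   // atom family defined above
+//   auto tiled_mma = cute::make_tiled_mma(Op{});                // one warpgroup, one atom per tile
+//   auto thr_mma   = tiled_mma.get_thread_slice(threadIdx.x);   // per-thread view (device code)
+//
+// Tensors are then partitioned through thr_mma.partition_A/B/C (and partition_fragment_C for the
+// register accumulator) before cute::gemm(tiled_mma, ...) issues the wgmma instructions.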
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, 
%23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+struct SM90_64x192x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t 
& d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 
p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + 
float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t 
& d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, 
uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float 
& d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), 
"+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/arch/util.hpp b/server/punica_kernels/include/cutlass/cute/arch/util.hpp new file mode 100644 index 00000000..2d95ec41 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/arch/util.hpp @@ -0,0 +1,255 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#if defined(__clang__) && defined(__CUDA__) + // __cvta_generic_to_shared was added in Clang 14: https://reviews.llvm.org/D111665 + #if __clang_major__ >= 14 + #define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 + #endif + + // __nvvm_get_smem_pointer added in Clang 14: https://reviews.llvm.org/D111665 + // ... but will not work on Windows until Clang 15: https://reviews.llvm.org/D122897 + #if (!defined(_WIN32) && __clang_major__ >= 14) || __clang_major__ >= 15 + #define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER 1 + #endif +#endif + +#if defined(__NVCC__) || defined(__CUDACC_RTC__) + // __cvta_generic_to_shared added in CUDA 11+ + #if __CUDACC_VER_MAJOR__ >= 11 + #define CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 + #endif + + // __nvvm_get_smem_pointer added in CUDA 10.2 + #if __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2 + #define CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER 1 + #endif +#endif + +#if CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED + #define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED 1 +#endif + +#if !defined(CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED) && CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED && defined(__CUDA_ARCH__) + #define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED 1 +#endif + +#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER + #define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED 1 +#endif + +#if !defined(CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED) && CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED && defined(__CUDA_ARCH__) + #define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED 1 +#endif + +// Clang 14+ provides a declaration of __nvvm_get_smem_pointer, so we only need +// to provide one for NVCC +#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER + extern "C" { + // This NVVM intrinsic is subject to change in future versions of CUDA. + // Clients should not call it directly. + CUTE_DEVICE uint32_t __nvvm_get_smem_pointer(void*); + } +#endif + +namespace cute +{ + +/// CUTE helper to cast SMEM pointer to unsigned +CUTE_DEVICE +uint32_t +cast_smem_ptr_to_uint(void const* const ptr) +{ +// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to +// the previous internal intrinsics if they are available. +#if CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED + // + // This NVVM intrinsic converts an address in shared memory to a plain + // unsigned integer. This is necessary to pass to shared memory instructions + // in inline PTX. + // + // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2]. 
+ // + //__device__ size_t __cvta_generic_to_shared(void* ptr); + + /// CUTE helper to get SMEM pointer + return static_cast(__cvta_generic_to_shared(ptr)); + +#elif CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED + + return __nvvm_get_smem_pointer(ptr); + +#elif defined(__CUDA_ARCH__) + + uint32_t smem_ptr; + + asm( + "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" + : "=r"(smem_ptr) : "l"(ptr)); + + return smem_ptr; + +#else + + + (void) ptr; + printf("ERROR: cast_smem_ptr_to_uint not supported but used.\n"); + return 0; + +#endif +} + +// +// Utility for pointer interfaces +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrS&& s, int_sequence, + PtrD&& d, int_sequence) +{ + return fn(s[Is]..., d[Id]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence) +{ + return fn(a[Ia]..., b[Ib]..., c[Ic]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode_with_d_scaling(Fn fn, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + ParamType&& p0) +{ + return fn(a[Ia]..., b[Ib]..., c[Ic]..., p0); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode_with_d_scaling(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + ParamType&& p0) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., p0); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, PtrS&& s, PtrD&& d) +{ + return detail::explode(fn, + s, make_int_sequence{}, + d, make_int_sequence{}); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, PtrA&& a, PtrB&& b, PtrC&& c) +{ + return detail::explode(fn, + a, make_int_sequence{}, + b, make_int_sequence{}, + c, make_int_sequence{}); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, PtrD&& d, PtrA&& a, PtrB&& b, PtrC&& c) +{ + return detail::explode(fn, + d, make_int_sequence{}, + a, make_int_sequence{}, + b, make_int_sequence{}, + c, make_int_sequence{}); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/atom/copy_atom.hpp b/server/punica_kernels/include/cutlass/cute/atom/copy_atom.hpp new file mode 100644 index 00000000..d1cd3d4b --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/copy_atom.hpp @@ -0,0 +1,772 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include +#include + +#include + +#include + +namespace cute +{ + +template +struct Copy_Atom; + +template +struct Copy_Atom : Copy_Atom, CopyInternalType> +{}; + +template +struct Copy_Atom, CopyInternalType> + : Copy_Traits +{ + using Traits = Copy_Traits; + + // Bit and Thr layouts from the Copy_Traits + using ThrID = typename Traits::ThrID; + using BitLayoutSrc = typename Traits::SrcLayout; + using BitLayoutDst = typename Traits::DstLayout; + using BitLayoutRef = typename Traits::RefLayout; + + using ValType = CopyInternalType; + + using ValLayoutSrc = decltype(recast_layout(BitLayoutSrc{})); + using ValLayoutDst = decltype(recast_layout(BitLayoutDst{})); + using ValLayoutRef = decltype(recast_layout(BitLayoutRef{})); + + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), "CopyOperation is not valid for Ref of ValType."); + + static constexpr int NumValSrc = size<1>(ValLayoutSrc{}); + static constexpr int NumValDst = size<1>(ValLayoutDst{}); + + // Additional Trait parameters/transformations + template + CUTE_HOST_DEVICE + auto + with(TraitsArgs&&... args) const { + auto traits = Traits::with(static_cast(args)...); + return Copy_Atom{traits}; + } + + // + // Tensor call interfaces + // + + // Check and call instruction, or recurse + template + CUTE_HOST_DEVICE + void + call(Tensor const& src, + Tensor & dst) const + { + static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); + static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); + + if constexpr (is_constant::value || + is_constant::value) { + // Dispatch to unpack to execute instruction + return copy_unpack(*this, src, dst); + } else + if constexpr (is_tuple::value && + is_tuple::value) { + // If the size of the src/dst doesn't match the instruction, + // recurse this rank-1 layout by peeling off the mode + // ((A,B,C,...)) -> (A,B,C,...) 
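+      // Aside -- illustrative usage sketch (names assumed, not from this file):
+      // a Copy_Atom is normally built from a copy op plus a value type and driven
+      // through cute::copy, which eventually lands in this call() dispatch, e.g.
+      //
+      //   auto atom = Copy_Atom<UniversalCopy<uint128_t>, half_t>{};
+      //   copy(atom, src_frag, dst_frag);   // src_frag / dst_frag: rank-1 fragments
+      //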
+ return copy(*this, tensor<0>(src), tensor<0>(dst)); + } else { + static_assert(dependent_false, "No instruction match and no recursion possible."); + } + } + + // Accept mutable temporaries + template + CUTE_HOST_DEVICE + void + call(Tensor const& src, + Tensor && dst) const + { + return call(src, dst); + } +}; + +// +// A tiling of copy atoms +// + +template +struct ThrCopy; + +template coord [Need not be 2D...] + class ShapeTiler_MN> // coord space +struct TiledCopy : Copy_Atom +{ + // Layout information from the CopyAtom + using AtomThrID = typename Copy_Atom::ThrID; // thrid -> thr_idx + using AtomLayoutSrc = typename Copy_Atom::ValLayoutSrc; // (thr,val) -> offset + using AtomLayoutDst = typename Copy_Atom::ValLayoutDst; // (thr,val) -> offset + using AtomLayoutRef = typename Copy_Atom::ValLayoutRef; // (thr,val) -> offset + + using AtomNumThr = decltype(size<0>(AtomLayoutRef{})); + using AtomNumVal = decltype(size<1>(AtomLayoutRef{})); + + // Layout information for the TiledCopy + using Tiler_MN = ShapeTiler_MN; + using TiledLayout_TV = LayoutCopy_TV; + using TiledNumThr = decltype(size<0>(TiledLayout_TV{})); + using TiledNumVal = decltype(size<1>(TiledLayout_TV{})); + + CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom"); + CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom"); + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + // where + // ThrV: The threads local to a COPY_ATOM Src. + // ThrX: The threads tiled across COPY_ATOMs Src. + // FrgV: The values local to a COPY_ATOM Src. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_S(STensor&& stensor) + { + CUTE_STATIC_ASSERT_V(rank(stensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small."); + + // Tile the stensor and compute the (src-thr, src-val) -> (ref-thr, ref-val) layout + return tile2thrfrg(zipped_divide(stensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{})); + } + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + // where + // ThrV: The threads local to a COPY_ATOM Dst. + // ThrX: The threads tiled across COPY_ATOMs Dst. + // FrgV: The values local to a COPY_ATOM Dst. + // RestM: The values tiled in M. + // RestN: The values tiled in N. 
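+  //
+  // Worked example (sizes are illustrative assumptions, not from this file):
+  // with Tiler_MN = (_32,_8), a copy atom of 8 threads x 4 values, and 64 threads
+  // in the TiledCopy, a (128,64) tensor is reshaped as
+  //   (128,64) -> ((ThrV=8, ThrX=8), FrgV=4, (RestM=4, RestN=8))
+  // mode-0 carries the thread decomposition, mode-1 the per-thread values, and
+  // mode-2 the tile repetitions over the original tensor.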
+ template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_D(DTensor&& dtensor) + { + CUTE_STATIC_ASSERT_V(rank(dtensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small."); + + // Tile the dtensor and compute the (dst-thr, dst-val) -> (ref-thr, ref-val) layout + return tile2thrfrg(zipped_divide(dtensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{})); + } + + // Tile a tensor or a layout from shape + // ((TileM,TileN,...), (RestM,RestN,...)) + // to shape + // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + template + CUTE_HOST_DEVICE constexpr static + auto + tile2thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg) + { + // Take the thrs/vals that the atom is interested in + // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID + auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{})); + // ((atom_tid,atom_val),(rest_tid,rest_val)) -> (m,n) + + // Transform to the trg layout + auto trg_layout_TV = atom_layout_TV.compose(ref2trg, _); + // ((trg_tid,trg_val),(rest_tid,rest_val)) -> (m,n) + + // Transform the thrs mode from thrid to thr_idx + // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID + auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1,Shape<_1,_1>>{}); + // ((trg_tid,rest_tid),(trg_val,rest_val)) -> (m,n) + + /// ================== + + // Transform the tile mode + auto tv_tensor = tensor.compose(thrval2mn, _); + // ((thrid,val),(RestM,RestN,...)) + + // Unfold and return + return tv_tensor(make_coord(_,_), _); + } + + // retile_S and retile_D assume they are working with the reference layout -- they are the same + template + CUTE_HOST_DEVICE constexpr static + auto + retile(Tensor&& tensor) + { + constexpr int R = remove_cvref_t::rank; + // Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref transformation + + // Assume the first size<0>(tensor) elements are the first val_ids in TiledLayout_TV. 
+ // Then, we only need the shape+layout of those size<0>(tensor) elements in TiledLayout_TV + // and that shape is what we gather from the other modes of tensor + + auto V = size<0>(tensor); + + auto frg_layout_mn = upcast(right_inverse(TiledLayout_TV{}).with_shape(shape(Tiler_MN{}))); + // (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV + + auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{})); + // (atom_vals,rest_vals) -> (v,m,n) + + /// ======= + + // Tile the tensor for TileFrg + auto t_tensor = zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V)); + // ((TileV,TileM,TileN,...),(1,RestM,RestN,...)) + + // Transform the tile mode + auto v_tensor = t_tensor.compose(frg_layout_v, _); + // ((atom_vals,rest_vals),(1,RM,RN,...)) + + // Unfold and return + return v_tensor(_, append(Int<0>{},_)); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutS_TV() + { + // (M,N) -> (M,N) + auto ref_S = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{})); + // (thr_idx,val_idx) -> (M,N) + return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{}); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutS_MN() + { + // (thr_idx,val_idx) -> (M,N) + auto layoutS_TV = get_layoutS_TV(); + // (M,K) -> (thr_idx,val_idx) + auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(Tiler_MN{})); + + // athrid = (v,m,k) -> thr_idx + auto thrID_S = make_layout(size<0>(TiledLayout_TV{})); + + return cute::make_tuple(layoutS_MK, thrID_S); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutD_TV() + { + // (M,N) -> (M,N) + auto ref_D = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{})); + // (thr_idx,val_idx) -> (M,N) + return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{}); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutD_MN() + { + // (thr_idx,val_idx) -> (M,N) + auto layoutD_TV = get_layoutD_TV(); + // (M,K) -> (thr_idx,val_idx) + auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(Tiler_MN{})); + + // athrid = (v,m,k) -> thr_idx + auto thrID_D = make_layout(size<0>(TiledLayout_TV{})); + + return cute::make_tuple(layoutD_MK, thrID_D); + } + + template ::value)> + CUTE_HOST_DEVICE static + auto + get_slice(ThrIdx const& thr_idx) + { + return ThrCopy(thr_idx); + } + + template ::value)> + CUTE_HOST_DEVICE static + auto + get_thread_slice(ThrIdx const& thr_idx) + { + return get_slice(thr_idx); + } +}; + +template +struct ThrCopy +{ + ThrIdx thr_idx_; + + CUTE_HOST_DEVICE + ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {} + + template + CUTE_HOST_DEVICE + auto + partition_S(STensor&& stensor) const { + //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling SrcTensor."); + auto thr_tensor = make_tensor(static_cast(stensor).data(), TiledCopy::tidfrg_S(stensor.layout())); + return thr_tensor(thr_idx_, _, repeat>(_)); + } + + template + CUTE_HOST_DEVICE + auto + partition_D(DTensor&& dtensor) const { + //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling DstTensor."); + auto thr_tensor = make_tensor(static_cast(dtensor).data(), TiledCopy::tidfrg_D(dtensor.layout())); + return thr_tensor(thr_idx_, _, repeat>(_)); + } + + template + CUTE_HOST_DEVICE static + auto + retile_S(STensor&& stensor) { + // 
static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling SrcTensor."); + return make_tensor(static_cast(stensor).data(), TiledCopy::retile(stensor.layout())); + } + + template + CUTE_HOST_DEVICE static + auto + retile_D(DTensor&& dtensor) { + // static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling DstTensor."); + return make_tensor(static_cast(dtensor).data(), TiledCopy::retile(dtensor.layout())); + } +}; + + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_impl(Copy_Atom const& atom, + LayoutCopy_TV const&, + Tiler const&) +{ + return TiledCopy, LayoutCopy_TV, Tiler>{atom}; +} + +// +// These tile the Copy_Atom as a whole +// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_A(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + return make_tiled_copy_impl(copy_atom, mma.get_layoutA_TV(), make_shape(tile_size<0>(mma),tile_size<2>(mma))); +} + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_B(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + return make_tiled_copy_impl(copy_atom, mma.get_layoutB_TV(), make_shape(tile_size<1>(mma),tile_size<2>(mma))); +} + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + return make_tiled_copy_impl(copy_atom, mma.get_layoutC_TV(), make_shape(tile_size<0>(mma),tile_size<1>(mma))); +} + +// returns the smallest tiled copy that can retile LayoutC_TV +// for use with pipelined epilogues with subtiled stores +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C_atom(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + // Truncate the V-layout to just the Copy_Atom, keep the V-order + auto layoutC_TV = mma.get_layoutC_TV(); + auto copy_V = Int::NumValSrc>{}; + CUTE_STATIC_ASSERT_V(copy_V <= size<1>(layoutC_TV)); + auto layout_TV = composition(layoutC_TV, make_layout(make_shape(size<0>(layoutC_TV), copy_V))); + + // Recompute tiler and restride the TV layout for the new tiler + + // Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them + // Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA + auto mma_tiler = make_shape(tile_size<0>(mma),tile_size<1>(mma)); + auto mma_zeros = repeat_like(mma_tiler, Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(mma_tiler, replace(mma_zeros, Int<1>{})), layout_TV)); + }); + + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // Apply the tiler to a reference and transform the codomain + // tile_coord -> mma_coord + auto tile2mma = composition(make_layout(mma_tiler), tiler); + + // (tid,vid) -> tile_coord + auto layout_tv = composition(left_inverse(tile2mma), layout_TV); + + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + +/** Produce a TiledCopy from logical thread and values layouts. + * The thread and value layouts map coordinates to thr_idx and val_idx. + * The product of these layouts is taken to produce the TV layout and the Tiler. + * Useful when threads and values need very specific mappings onto coordinates + * in the target tensors. 
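+ *
+ * Illustrative usage sketch (the atom and thread/value shapes below are
+ * assumptions, not prescribed by this file):
+ *
+ *   auto tiled_copy = make_tiled_copy(
+ *       Copy_Atom<UniversalCopy<uint128_t>, half_t>{},
+ *       Layout<Shape<_16,_8>, Stride<_8,_1>>{},   // (m,n) -> thr_idx : 128 threads
+ *       Layout<Shape< _1,_8>>{});                 // (m,n) -> val_idx : 8 halfs/thread
+ *
+ *   auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
+ *   auto tS = thr_copy.partition_S(src_tensor);   // src_tensor assumed defined
+ *   auto tD = thr_copy.partition_D(dst_tensor);   // dst_tensor assumed defined
+ *   copy(tiled_copy, tS, tD);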
+ */ +template > +CUTE_HOST_DEVICE +auto +make_tiled_copy(Copy_Atom const& copy_atom, + ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx + ValLayout const& val_layout = {}) // (m,n) -> val_idx +{ + // Take the raked_products to compute the Layout_MN + // (M,N) -> (thr_idx, val_idx) + auto layout_mn = raked_product(thr_layout, val_layout); + // (thr_idx, val_idx) -> (M,N) + auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); + // Tiler for extracting relevant elements + // (M,N) -> tensor coord + auto tiler = product_each(shape(layout_mn)); + +#if 0 + print("thr_layout: "); print(thr_layout); print("\n"); + print("val_layout: "); print(val_layout); print("\n"); + print("layout_mn : "); print(layout_mn); print("\n"); + print("layout_tv : "); print(layout_tv); print("\n"); + print("tiler : "); print(tiler); print("\n"); +#endif + + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + +/** Produce a TiledCopy from thread and value offset maps. + * The TV Layout maps threads and values to the codomain of the data_layout. + * It is verified that the intended codomain is valid within data_layout. + * Useful when threads and values don't care about owning specific coordinates, but + * care more about the vector-width and offsets between them. + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_cotiled_copy(Copy_Atom const& copy_atom, + AtomTVLayout const& atom_tv_layout, // atom (thr,val) -> data addr + DataLayout const& data_layout) // coord -> data addr The target layout +{ + static_assert(is_static::value); + static_assert(is_static::value); + + // data addr -> data coord Append 1:0 so off-the-ends get the stride-0 + auto inv_data_layout = make_layout(left_inverse(data_layout), Layout<_1,_0>{}); + + // (tid,vid) -> data_coord + auto layout_tv_data = composition(inv_data_layout, atom_tv_layout); + + // Check validity + CUTE_STATIC_ASSERT_V(coalesce(composition(data_layout, layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), + "The memory pointed to by AtomTVLayout does not exist in the DataLayout."); + +#if 0 + if (thread0()) { + print("data_layout : "); print(data_layout); print("\n"); + print("atom_tv_layout : "); print(atom_tv_layout); print("\n"); + print("layout_tv_data : "); print(layout_tv_data); print("\n"); + } +#endif + + // + // Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them + // + + // Convert to the awkward by-mode tiler to preserve the modes of the tiled DATA + auto flat_data_shape = product_each(shape(data_layout)); + auto flat_data_zeros = repeat(Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(flat_data_shape, replace(flat_data_zeros, Int<1>{})), layout_tv_data)); + }); + + // + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // + + // Apply the tiler to a reference and transform the codomain + // tile_coord -> data_coord + auto tile2data = composition(make_layout(flat_data_shape), tiler); + + // (tid,vid) -> tile_coord + auto layout_tv = composition(left_inverse(tile2data), layout_tv_data); + +#if 0 + if (thread0()) { + print("tiler : "); print(tiler); print("\n"); + print("tile2data : "); print(tile2data); print("\n"); + print("layout_tv : "); print(layout_tv); print("\n"); + } +#endif + + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + +// Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy +template +CUTE_HOST_DEVICE +auto 
+make_tiled_copy_S(Copy_Atom const& copy_atom, + TiledCopy const& tiled_copy) +{ + return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), typename TiledCopy::Tiler_MN{}); +} + +// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_D(Copy_Atom const& copy_atom, + TiledCopy const& tiled_copy) +{ + return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), typename TiledCopy::Tiler_MN{}); +} + +// +// Size +// + +// The logical size of a TileCopy +template +CUTE_HOST_DEVICE constexpr +auto +tile_size(TiledCopy const&) +{ + return size(typename TiledCopy::Tiler_MN{}); +} + +// The number of threads involved in a TiledCopy +template +CUTE_HOST_DEVICE constexpr +auto +size(TiledCopy const&) +{ + return typename TiledCopy::TiledNumThr{}; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE +void +print(Copy_Atom, T> const&) +{ + using Atom = Copy_Atom, T>; + print("Copy_Atom\n"); + print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" ValLayoutSrc: "); print(typename Atom::ValLayoutSrc{}); print("\n"); + print(" ValLayoutDst: "); print(typename Atom::ValLayoutDst{}); print("\n"); + print(" ValLayoutRef: "); print(typename Atom::ValLayoutRef{}); print("\n"); + print(" ValueType: "); print(sizeof_bits::value); print("b\n"); +} + +template +CUTE_HOST_DEVICE +void +print(TiledCopy const& copy, char const* pad = "") +{ + using Copy = TiledCopy; + print("TiledCopy\n"); + print(" Tiler_MN: "); print(typename Copy::Tiler_MN{}); print("\n"); + print(" TiledLayout_TV: "); print(typename Copy::TiledLayout_TV{}); print("\n"); + print(static_cast(copy)); +} + +template +CUTE_HOST_DEVICE +void +print(ThrCopy const& thr_copy) +{ + print("ThrCopy\n"); + print(" ThrIdx: "); print(thr_copy.thr_idx_); print("\n"); + print(TiledCopy{}); +} + +template +CUTE_HOST_DEVICE +auto +print_latex(TiledCopy const& copy) +{ + auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN(); + auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN(); + + print_latex_copy(layoutS_MN, thrID_S, + layoutD_MN, thrID_D); +} + +// MNK Copy Layout to Latex TIKZ -- 8-value color coded by thread +template +CUTE_HOST_DEVICE +void +print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutD const& D, ThrIDD const& TD) // (m,n) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); + + assert(size<0>(S) == size<0>(D)); + assert(size<1>(S) == size<1>(D)); + + char const* latex_header = + "\\documentclass{standalone}\n" + "\\usepackage{tikz}\n" + "\\usetikzlibrary{external}\n" + "\\tikzexternalize\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; + char const* latex_footer = + "\\end{tikzpicture}\n" + "\\end{document}\n"; + + char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}",}; + + // Header + printf("%% LayoutS: "); print(S); printf("\n"); + printf("%% ThrIDS : "); print(TS); printf("\n"); + printf("%% LayoutD: "); print(D); printf("\n"); + printf("%% ThrIDD : "); print(TD); printf("\n\n"); + + 
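+  // Illustrative note (a sketch): this routine is normally reached through the
+  // print_latex(TiledCopy) overload above, e.g.
+  //
+  //   print_latex(some_tiled_copy);   // emits a standalone TikZ document to stdout
+  //
+  // and the output is compiled with pdflatex; thread colors cycle through the
+  // 8-entry color_map above.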
printf(latex_header); + + // S starting at 0,0 + for (int i = 0; i < size<0>(S); ++i) { + for (int j = 0; j < size<1>(S); ++j) { + int thrid = S(i,j) % size(TS); + int val_idx = S(i,j) / size(TS); + int thr_idx = TS(thrid); + + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[thr_idx % 8], + i, j, + thr_idx, val_idx); + } + } + + // D starting at 0,size<1>(S)+3 + for (int i = 0; i < size<0>(D); ++i) { + for (int j = 0; j < size<1>(D); ++j) { + int thrid = D(i,j) % size(TD); + int val_idx = D(i,j) / size(TD); + int thr_idx = TD(thrid); + + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[thr_idx % 8], + i, j + size<1>(S) + 3, + thr_idx, val_idx); + } + } + + // S Labels + for (int i = 0, j = -1; i < size<0>(S); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); + } + for (int j = 0, i = -1; j < size<1>(S); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); + } + // D Labels + for (int i = 0, j = size<1>(D); i < size<0>(S); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i); + } + for (int j = 0, i = -1; j < size<1>(D); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j); + } + + // Footer + printf(latex_footer); +} + +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include + +// Config +#if (__CUDACC_VER_MAJOR__ >= 12) +# define CUTE_COPY_ATOM_TMA_SM90_ENABLED +#endif + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +#include +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cute/atom/copy_traits.hpp b/server/punica_kernels/include/cutlass/cute/atom/copy_traits.hpp new file mode 100644 index 00000000..b6259b58 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/copy_traits.hpp @@ -0,0 +1,200 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +/** + * concept Copy_Traits + * { + * using ThrID = // Logical thread id (tid) -> tidx + * + * using SrcLayout = // (Logical src thread id (tid), Logical src value id (vid)) -> bit + * using DstLayout = // (Logical dst thread id (tid), Logical dst value id (vid)) -> bit + * using RefLayout = // (Logical ref thread id (tid), Logical ref value id (vid)) -> bit + * }; + * + * The abstract bit ordering of the Copy_Traits (the codomain of SrcLayout, DstLayout, and RefLayout) + * is arbitrary and only used to construct maps + * (ref-tid,ref-vid) -> (src-tid,src-vid) + * (ref-tid,ref-vid) -> (dst-tid,dst-vid) + * in TiledCopy. The Layout_TV in TiledCopy is in accordance with the RefLayout of a Traits, then mapped to + * the Src or Dst (tid,vid) representation on demand. + * + */ + +template +struct Copy_Traits +{ + static_assert(dependent_false, "Copy_Traits not implemented for this CopyOperation."); +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0,_0>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride<_0,_0>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +namespace detail { + +// Utility for exploding pointers, arrays, or tensors into Operation::copy +template +CUTE_HOST_DEVICE constexpr +void +copy_explode_index(PtrSrc&& s, int_sequence, + PtrDst&& d, int_sequence) +{ + return Operation::copy(s[Is]..., d[Id]...); +} + +// Utility for exploding tuples into ::copy +template +CUTE_HOST_DEVICE constexpr +void +copy_explode(TupleArg&& t, int_sequence) +{ + return Operation::copy(get(static_cast(t))...); +} + +template +CUTE_HOST_DEVICE constexpr +void +copy_explode(TupleSrc&& s, int_sequence, + TupleDst&& d, int_sequence) +{ + return Operation::copy(get(static_cast(s))..., + get(static_cast(d))...); +} + +template +CUTE_HOST_DEVICE constexpr +void +copy_explode(TupleAux&& a, int_sequence, + TupleSrc&& s, int_sequence, + TupleDst&& d, int_sequence) +{ + return Operation::copy(get(static_cast(a))..., + get(static_cast(s))..., + get(static_cast(d))...); +} + +} // end namespace detail + +// +// Generic copy_unpack for common argument-based Copy_Traits +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const&, + Tensor const& src, + Tensor & dst) +{ + // Specializations can generalize on these checks + //static_assert(is_smem::value, 
"Expected smem for this Copy_Traits"); + //static_assert(is_rmem::value, "Expected rmem for this Copy_Traits"); + + using RegistersSrc = typename CopyOp::SRegisters; + using RegistersDst = typename CopyOp::DRegisters; + using RegTypeSrc = typename remove_extent::type; + using RegTypeDst = typename remove_extent::type; + constexpr int RegNumSrc = extent::value; + constexpr int RegNumDst = extent::value; + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int{}, + "Copy_Traits: src failed to vectorize into registers. Layout is incompatible with this CopyOp."); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "Copy_Traits: dst failed to vectorize into registers. Layout is incompatible with this CopyOp."); + + detail::copy_explode_index(rS, make_int_sequence{}, + rD, make_int_sequence{}); +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor && dst) +{ + copy_unpack(traits, src, dst); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm75.hpp b/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm75.hpp new file mode 100644 index 00000000..9ad82c61 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm75.hpp @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2>>, + Stride,Stride< _1,_128>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2, _2>>, + Stride,Stride< _1,_128,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2, _4>>, + Stride,Stride< _1,_128,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm80.hpp b/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm80.hpp new file mode 100644 index 00000000..e5ff0b7b --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm80.hpp @@ -0,0 +1,194 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // Predicate value that determines whether to load or zfill + bool pred = false; + + // Overload copy_unpack for zfill variant to pass the predicate into the op + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem source for cp.async."); + static_assert(is_smem::value, "Expected smem destination for cp.async."); + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int<1>{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + CUTE_STATIC_ASSERT_V(size(rD) == Int<1>{}, + "In CopyAtom, dst layout doesn't vectorize into registers. 
This dst layout is incompatible with this tiled copy."); + + SM80_CP_ASYNC_CACHEALWAYS_ZFILL::copy(rS[0], rD[0], traits.pred); + } +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // Predicate value that determines whether to load or zfill + bool pred = false; + + // Overload copy_unpack for zfill variant to pass the predicate into the op + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem source for cp.async."); + static_assert(is_smem::value, "Expected smem destination for cp.async."); + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int<1>{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + CUTE_STATIC_ASSERT_V(size(rD) == Int<1>{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); + + SM80_CP_ASYNC_CACHEGLOBAL_ZFILL::copy(rS[0], rD[0], traits.pred); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Element copy selector +template +CUTE_HOST_DEVICE constexpr +auto +select_elementwise_copy(SrcTensor const&, DstTensor const&) +{ + using SrcType = typename SrcTensor::value_type; + using DstType = typename DstTensor::value_type; + +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + if constexpr (is_gmem::value && is_smem::value && + sizeof(SrcType) == sizeof(DstType) && + (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16)) + { + return SM80_CP_ASYNC_CACHEALWAYS{}; + } else { + return UniversalCopy{}; + } + + CUTE_GCC_UNREACHABLE; +#else + return UniversalCopy{}; +#endif +} + +} diff --git a/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm90.hpp b/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm90.hpp new file mode 100644 index 00000000..f9590848 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm90.hpp @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm90_im2col.hpp b/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm90_im2col.hpp new file mode 100644 index 00000000..34e71ed6 --- /dev/null +++ 
b/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm90_im2col.hpp @@ -0,0 +1,879 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief im2col make_tma_copy +*/ + +#include "cute/arch/copy_sm90.hpp" +#include "cute/arch/copy_sm90_desc.hpp" +#include "cute/tensor.hpp" + +#include "cute/algorithm/prefetch.hpp" + +namespace cute +{ + +// Utility for unpacking TMA_LOAD_IM2COL arguments into a CopyOp +template +struct TMA_LOAD_IM2COL_Unpack +{ + /// Copy from src to dst. + /// + /// @param traits Copy traits created with a TMA descriptor that + /// correctly matches the input tensor and other convolution + /// parameters. + /// + /// @param src Tile of the im2col-transformed coordinate tensor + /// (result of get_tma_tensor), representing the global-memory + /// tensor from which to load. + /// + /// @param dst Shared memory tile, into which to load. 
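+///
+/// Background sketch (illustrative, not from this file): im2col turns a convolution
+/// into a GEMM by gathering, for each output position, the window of activation
+/// elements it reads. For a 1-D convolution with W=5, filter S=3, stride 1 there
+/// are Q=3 output positions, and the im2col view is a (Q=3) x (S*C) matrix whose
+/// row q gathers activation columns {q, q+1, q+2} for every channel c. The TMA
+/// im2col descriptor performs this gather on the fly rather than materializing
+/// the matrix.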
+ template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, // tile of the transformed global activation (A) tensor + Tensor & dst) // shared memory tile + { + auto src_coord_offset = src(Int<0>{}); + auto src_coord_cwhdn_offset_srt = flatten(src_coord_offset); + // Interpret the TMA IM2COL coordinate as (c, ([w,h,d]), n, ([s,r,t])) + CUTE_STATIC_ASSERT_V(rank(src_coord_offset) == _4{}); + CUTE_STATIC_ASSERT_V(rank<1>(src_coord_offset) == rank<3>(src_coord_offset)); + + if constexpr (detail::is_prefetch) { + return detail::copy_explode(traits.opargs_, tuple_seq{}, + src_coord_cwhdn_offset_srt, tuple_seq{}); + } else { + static_assert(is_smem::value, "SM90_TMA_LOAD_IM2COL requires the destination be shared memory."); + void* dst_ptr = cute::raw_pointer_cast(dst.data()); + return detail::copy_explode(traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord_cwhdn_offset_srt, tuple_seq{}); + } + } +}; + +// Copy_Traits for SM90 im2col TMA load comes in two layers. +// +// 1. Copy_Traits +// 2. Copy_Traits +// +// Copy_Traits +// is the "outer" layer. It has a TMA descriptor, +// but no barrier ("tma_mbar"), so it's "nonexecutable." +// One calls its "with" member function with a barrier, +// to get an executable "inner"-layer +// Copy_Traits object. +// That object's "copy_unpack" member function +// actually invokes im2col TMA load. + +struct SM90_TMA_LOAD_IM2COL_OP : SM90_TMA_LOAD_IM2COL {}; + +/// @brief Non-executable specialization of Copy_Traits for SM90 +/// im2col TMA load, with TMA descriptor but no barrier. +/// +/// Use `.with(memory_barrier)` to construct an executable version. +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const + { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM90_TMA_LOAD_IM2COL are not + /// directly executable. Instead, call this "with" member function + /// to get an executable specialization. "Executable" means that + /// @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (unused; only exists + /// for interface compatibility with the actual multicast Copy_Traits) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const + { + return {{}, {&tma_desc_, &tma_mbar}}; + } + + // Copy_Traits specializations with SM90_TMA_LOAD_IM2COL + // are not directly executable. Instead, call .with + // to get an executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// @brief Executable specialization of Copy_Traits for SM90 im2col +/// TMA load, with TMA descriptor and barrier. 
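+///
+/// Usage sketch (illustrative; tma_atom / thr_tma / gIm2col / sTile are assumed
+/// names): the descriptor-only traits are made executable inside the kernel by
+/// binding a shared-memory barrier (a uint64_t mbarrier) via .with():
+///
+///   auto tma_exec = tma_atom.with(mbar);          // nonexecutable -> executable
+///   copy(tma_exec, thr_tma.partition_S(gIm2col),  // gIm2col from get_tma_tensor
+///                  thr_tma.partition_D(sTile));   // sTile: shared memory tile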
+template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_IM2COL arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t* // smem mbarrier + > const opargs_; +}; + +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_IM2COL::PREFETCH arguments + tuple const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(Copy_Traits const& traits) + : opargs_({&traits.tma_desc_}) {} +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_OP : SM90_TMA_LOAD_IM2COL_MULTICAST {}; + +/// @brief Non-executable specialization of Copy_Traits for SM90 +/// im2col TMA load, with TMA descriptor but no barrier or multicast +/// mask. +/// +/// Use `.with(memory_barrier)` to construct an executable version. +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM90_TMA_LOAD_IM2COL_MULTICAST + /// are not directly executable. Instead, call this "with" member + /// function to get an executable specialization. "Executable" + /// means that @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (defaults to a single CTA) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, uint16_t const& multicast_mask) const { + return {{}, {&tma_desc_, &tma_mbar, multicast_mask}}; + } + + // Copy_Traits specializations with SM90_TMA_LOAD_IM2COL_MULTICAST + // are not directly executable. Instead, call .with to get an + // executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// @brief Executable specialization of Copy_Traits for SM90 multicast +/// im2col TMA load, with TMA descriptor, barrier, and multicast mask. +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit. 
+ using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_IM2COL_MULTICAST arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t // multicast mask + > const opargs_; +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_STORE IM2COL//////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +// The executable SM90_TMA_STORE_IM2COL with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE_IM2COL arguments + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + // This is the copy_unpack dispatch for this Copy_Traits + // Src needs to be a smem tensor + // Dst needs to be a gmem tensor with TmaCoordIterator .data() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE_IM2COL"); + + void const* const desc_ptr = &(traits.tma_desc_); + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = flatten(take<0,3>(dst(Int<0>{}))); + + return detail::copy_explode(make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } +}; + +namespace detail { + +/// @brief Creates a TMA descriptor for im2col TMA load. +/// +/// @param tensor_cwhdn Global activation tensor (A matrix of Fprop). +/// This is the original (not im2col-transformed) tensor in global +/// memory. +/// +/// @param slayout Rank 2 (M,K) shared memory layout of the activation +/// tensor. Here, K is "GEMM K," not the filter tensor's mode of +/// the same name. +////// +/// @param traversal_stride Traversal strides convolution parameter +////// +/// Each of padding_shape, traversal_stride, and dilation_shape is a +/// tuple whose size is the number of spatial modes (e.g., 3 for a 5-D +/// convolution). 
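+///
+/// For example (illustrative values): a 5-D NDHWC convolution with spatial stride 2
+/// and dilation 1 in every dimension would pass traversal_stride = (2,2,2) and a
+/// dilation shape of (1,1,1), i.e. rank-3 tuples matching the 3 spatial modes.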
+/// +/// @return TMA descriptor for im2col TMA load +template +CUTE_HOST +auto +make_im2col_tma_copy_desc( + Tensor const& tensor_cwhdn, // (C,W,H,D,N) + uint32_t range_c, // TILE_C + uint32_t range_whdn, // TILE_WHDN + SmemSwizzle const& smem_swizzle, // Swizzle + TMALayout const& tma_layout_vt, // TMA layout + LowerCornerStride const& lower_corner_whd, // WHD offset of the "base pointer" + UpperCornerStride const& upper_corner_whd, // WHD upper corner + LowerPaddingStride const& lower_padding_whd, // WHD lower padding + UpperPaddingStride const& upper_padding_whd, // WHD upper padding + TraversalStride const& stride_whd, // WHD traversal stride + LowerSRTStride const& lower_srt, // SRT offset of the "base pointer" + DilationStride const& stride_srt) // SRT stride - dilation +{ + static_assert(is_gmem::value, "Tensor must point to GPU global memory."); + using value_type = typename EngineA::value_type; + + constexpr uint32_t num_total_modes = LayoutA::rank; + constexpr int num_spatial_modes = num_total_modes - 2; + + // Gmem starting address + void* gmem_address = (void*) raw_pointer_cast(tensor_cwhdn.data()); + + // Gmem extents are just the tensor shape + cute::array gmem_prob_shape = {1,1,1,1,1}; + for_each(make_seq{}, [&](auto i) { + gmem_prob_shape[i] = static_cast(shape(tensor_cwhdn)); + }); + + // Gmem strides are byte strides of the activation tensor in CWHDN order + cute::array gmem_prob_stride = {0,0,0,0,0}; + for_each(make_seq{}, [&](auto i) { + gmem_prob_stride[i] = sizeof(value_type) * stride(tensor_cwhdn); + }); + + // Traversal strides are a function of the dilation shape + // corresponding to spatial (WHD) modes. + cute::array tma_traversal_strides = {1,1,1,1,1}; + for_each(make_seq{}, [&](auto i) { + tma_traversal_strides[i+1] = static_cast(get(stride_whd)); + }); + + cute::array tma_lower_corner{}; + for_each(make_seq{}, [&](auto i) { + tma_lower_corner[i] = static_cast(get(lower_corner_whd)); + }); + + cute::array tma_upper_corner{}; + for_each(make_seq{}, [&](auto i) { + tma_upper_corner[i] = static_cast(get(upper_corner_whd)); + }); + + Im2ColTmaDescriptor tma_desc; + +#if (__CUDACC_VER_MAJOR__ >= 12) + + CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); + CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; + CUtensorMapFloatOOBfill tma_oob_fill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(detail::get_tma_swizzle_bits(smem_swizzle)); + + CUresult encode_result = cuTensorMapEncodeIm2col( + &tma_desc, + tma_format, + num_total_modes, + gmem_address, + gmem_prob_shape.data(), + gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly sizeof(value_type) + tma_lower_corner.data(), + tma_upper_corner.data(), + range_c, + range_whdn, + tma_traversal_strides.data(), + tma_interleave, + tma_swizzle, + tma_l2Promotion, + tma_oob_fill); + + // The extra asserts help indicate the error's cause. + assert(encode_result != CUDA_ERROR_DEINITIALIZED); + assert(encode_result != CUDA_ERROR_NOT_INITIALIZED); + assert(encode_result != CUDA_ERROR_INVALID_CONTEXT); + assert(encode_result != CUDA_ERROR_INVALID_VALUE); + assert(encode_result == CUDA_SUCCESS); + +#endif // (__CUDACC_VER_MAJOR__ >= 12) + // + // Calculate gemm shapes and linearized shapes based on tma layout tiling. 
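+  // Illustrative check of the formula below (values assumed for exposition):
+  // for one spatial mode with w = 8, lower_corner = 0, upper_corner = -2
+  // (an unpadded filter of extent 3) and traversal stride 1,
+  //   q = (8 + (-2) - 0 - 1) / 1 + 1 = 6,
+  // which matches the familiar q = (w - s) / stride + 1 of an unpadded,
+  // unit-dilation convolution.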
+ // + + // Compute [w, h, d, n] + // q/p/z = (w/h/d + (upper_corner_whd - lower_corner_whd - 1)) / stride_whd + 1 + auto gemm_mn_ = cute::transform(cute::make_seq{}, [&](auto i) { + return (shape(tensor_cwhdn) + get(upper_corner_whd) - get(lower_corner_whd) - Int<1>{}) / get(stride_whd) + Int<1>{}; + }); + auto gemm_mn = append(gemm_mn_, shape(tensor_cwhdn)); + + // Compute [c, s, r, t] + // fprop/wgrad, s/r/t = 1 + (upper_padding_whd - upper_corner_whd) / stride_srt + // wgrad, s/r/t = 1 + (lower_padding_whd - lower_corner_whd) / stride_srt + auto gemm_k_ = cute::transform(cute::make_seq{}, [&](auto i) { + auto padding_size = conditional_return(get(stride_srt) > Int<0>{}, + get(upper_padding_whd) - get(upper_corner_whd), + get(lower_corner_whd) - get(lower_padding_whd)); + return Int<1>{} + padding_size / get(stride_srt); + }); + auto gemm_k = prepend(gemm_k_, shape<0>(tensor_cwhdn)); + + // For fprop/dgrad kernel, gemm_shapes is ((q, p, z, n), (c, s, r, t)) + // For wgrad kernel, gemm_shapes is ((c, s, r, t), (q, p, z, n)) + auto gemm_shapes_common = make_shape(gemm_mn, gemm_k); + auto gemm_shapes = make_shape( + basis_get(stride<0,1>(tma_layout_vt), gemm_shapes_common), + basis_get(stride<0,0>(tma_layout_vt), gemm_shapes_common)); + + // For fprop/dgrad kernel, linearized shapes is (whdn, (c, s, r, t)) + // For wgrad kernel linearized shapes is ((c, s, r, t), whdn) + auto linear_shapes_common = make_shape(size(gemm_mn), gemm_k); + auto linear_shapes = make_shape( + basis_get(stride<0,1>(tma_layout_vt), linear_shapes_common), + basis_get(stride<0,0>(tma_layout_vt), linear_shapes_common)); + + // + // Calculate gmem basis stride based on tma layout tiling. + // + + auto tma_basis_scale = make_shape(Int<1>{}, stride_whd, Int<1>{}, stride_srt); + auto tma_basis = elem_scale(tma_basis_scale, make_basis_like(tma_basis_scale)); + + auto gbasis_strides_common = make_stride( + append(get<1>(tma_basis), get<2>(tma_basis)), + prepend(get<3>(tma_basis), get<0>(tma_basis))); // ((w,h,d,n),(c,s,r,t)) + auto gbasis_strides = make_stride( + basis_get(stride<0,1>(tma_layout_vt), gbasis_strides_common), + basis_get(stride<0,0>(tma_layout_vt), gbasis_strides_common)); + + // + // Create tma tensor + // + + auto lower_corner = make_arithmetic_tuple(Int<0>{}, lower_corner_whd, Int<0>{}, lower_srt); + + auto tensor_multimode = make_tensor(ArithmeticTupleIterator(lower_corner), gemm_shapes, gbasis_strides); + auto tensor_linear = make_identity_tensor(linear_shapes); + auto tma_tensor = make_tensor(tensor_multimode.data(), composition( + tensor_multimode.layout(), + tensor_linear(Int<0>{}), + tensor_linear.layout())); + + return cute::make_tuple(tma_desc, tma_tensor); +} + +/// Make a TiledCopy for im2col TMA load. +/// +/// @param copy_op The copy implementation: either +/// SM90_TMA_LOAD_IM2COL or SM90_TMA_LOAD_IM2COL_MULTICAST. +/// +/// @param tensor_cwhdn The global tensor to use for im2col TMA loads. +/// For Fprop convolutions, this is the activation tensor. This is +/// the "original tensor that points to global memory, not the +/// coordinate (im2col-transformed) tensor. +/// +/// @param slayout Layout of shared memory tile. +/// +/// @param stride_whd The traversal strides convolution +/// parameter. +/// +/// @return TiledCopy specialization for im2col TMA loads. 
+template +CUTE_HOST_RTC +auto +make_tma_copy_im2col(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Layout const& cta_t_map, // CTA tid -> logical TMA tid + Layout const& cta_v_map, // CTA vid -> gmem coord + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, // traversal stride + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) // dilation +{ + // + // TMA parameter checking + // + + CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + CUTE_STATIC_ASSERT_V(size(slayout) % cosize(cta_t_map) == Int<0>{}, + "Number of active CTAs in TMA must divide domain size of slayout."); + + // + // TMA slayout manipulation + // + + // Invert the smem to get the largest contiguous vector in the smem layout + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // trunc_smem_idx -> trunc_smem_coord + + // Map from smem idx to a gmem mode + auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout)); + +#if 0 + print("g_layout : "); print(gtensor.layout()); print("\n"); + print("s_layout : "); print(slayout); print("\n"); + print("cta_t_map : "); print(cta_t_map); print("\n"); + print("cta_v_map : "); print(cta_v_map); print("\n"); + print("inv_smem : "); print(inv_smem_layout); print("\n"); + print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); +#endif + + // + // TMA gtensor manipulation + // + + // Generate a TupleBasis for the gtensor + auto glayout_basis = make_identity_layout(product_each(shape(gtensor))); + + // Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc + auto tma_layout_full = flatten(composition(glayout_basis, sidx_to_gmode)); + + // Truncate any incompatibilities -- no starting in the middle of gmodes + auto smem_rank = find_if(stride(tma_layout_full), [](auto e) { + [[maybe_unused]] auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; + }); + static_assert(smem_rank >= 2, "IM2COL expects at least 2 modes of the smem to vectorize with gmem."); + // IM2COL uses a maximum of 2 modes + constexpr int smem_tma_rank = cute::min(int(smem_rank), 2); + + // Keep only the static-1 basis modes into gmem + auto tma_layout_trunc = take<0,smem_tma_rank>(tma_layout_full); + + // Split according to the portion each multicast CTA will be responsible for + auto tma_layout_vt = logical_divide(tma_layout_trunc, shape_div(size(tma_layout_trunc), cosize(cta_t_map))); + +#if 0 + print("glayout_basis : "); print(glayout_basis); print("\n"); + print("tma_layout_full : "); print(tma_layout_full); print("\n"); + + print("tma_layout_trunc: "); print(tma_layout_trunc); print("\n"); + print("tma_layout_vt : "); print(tma_layout_vt); print("\n"); +#endif + + auto range_c = size<0,0>(tma_layout_vt); + auto range_whdn = size<0,1>(tma_layout_vt); + + Tensor gtensor_cwhdn = make_tensor(gtensor.data(), + flatten(make_layout(basis_get(stride<0,0>(tma_layout_vt), gtensor.layout()), + basis_get(stride<0,1>(tma_layout_vt), gtensor.layout())))); + + auto [tma_desc, tma_tensor] = make_im2col_tma_copy_desc( + gtensor_cwhdn, + range_c, + range_whdn, + detail::get_swizzle_portion(slayout), + tma_layout_vt, + lower_corner_whd, + upper_corner_whd, + lower_padding_whd, + upper_padding_whd, + stride_whd, + lower_srt, + stride_srt); + + // + 
// Construct the Copy_Traits + // + + using T = typename GEngine::value_type; + constexpr int num_bits_per_tma = decltype(size<0>(tma_layout_vt))::value * sizeof(T) * 8; + + using Traits = Copy_Traits, decltype(tma_tensor)>; + +#if 0 + print("num_bits : "); print(NumBitsPerTMA{}); print("\n"); +#endif + + Traits tma_traits{tma_desc, tma_tensor}; + + // + // Construct the TiledCopy + // + + auto cta_tiler = product_each(shape(cta_v_map)); + + // (CTA V, CTA T) -> smem_coord + auto layout_vt = composition(inv_smem_layout, make_layout(shape(tma_layout_vt))); + // Scale that up to cover all of the smem_coords + // + // The smem vector might not cover all of the tile, + // so multiply it up to cover the entire tile. + // "T" here (the parallel index) is a CTA index. + auto layout_VT = tile_to_shape(layout_vt, make_shape(size(cta_v_map)/size<1>(layout_vt), size<1>(layout_vt))); + // Flip it and change the domain of the T from logical thr to thr_idx + auto layout_TV = make_layout(composition(layout<1>(layout_VT), cta_t_map), layout<0>(layout_VT)); + +#if 0 + print("cta_tiler : "); print(cta_tiler); print("\n"); + print("layout_VT : "); print(layout_VT); print("\n"); + print("layout_TV : "); print(layout_TV); print("\n"); +#endif + + using T = typename GEngine::value_type; + return TiledCopy, decltype(layout_TV), decltype(cta_tiler)>{tma_traits}; +} + +/// Make a TiledCopy for im2col TMA with no offsets. +/// E.g. im2col TMA load for C and im2col TMA store for D. +template +CUTE_HOST_RTC +auto +make_tma_copy_im2col(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Layout const& cta_t_map, // CTA tid -> logical TMA tid + Layout const& cta_v_map) // CTA vid -> gmem coord +{ + constexpr int num_spatial_modes = rank<0>(GLayout{}) - 1; + return make_tma_copy_im2col(copy_op, gtensor, slayout, cta_t_map, cta_v_map, + append(Stride<_0>{}, Int<0>{}), // lower_corner_whd + append(Stride<_0>{}, Int<0>{}), // upper_corner_whd + append(Stride<_0>{}, Int<0>{}), // lower_padding_whd + append(Stride<_0>{}, Int<0>{}), // upper_padding_whd + append(Stride<_1>{}, Int<1>{}), // stride_whd + append(Stride<_0>{}, Int<0>{}), // lower_srt + append(Stride<_1>{}, Int<1>{})); // stride_srt +} + +} // namespace detail + + + +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler, + MulticastSize const& multicast_size, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) +{ + auto cta_v_tile = make_identity_layout(product_each(shape(tensor_cwhdn))).compose(cta_tiler); + auto cta_t_tile = make_layout(multicast_size); + + return detail::make_tma_copy_im2col(copy_op, tensor_cwhdn, + slayout, cta_t_tile, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt); +} + +// Explicit default for multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& 
lower_srt, + DilationStride const& stride_srt) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, cta_tiler, Int<1>{}, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt); +} + +// Explicit default for cta_tiler and multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, product_each(shape(slayout)), Int<1>{}, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt); +} + +// No offsets copy. +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler, + MulticastSize const& multicast_size) +{ + auto cta_v_tile = make_identity_layout(product_each(shape(tensor_cwhdn))).compose(cta_tiler); + auto cta_t_tile = make_layout(multicast_size); + + return detail::make_tma_copy_im2col(copy_op, tensor_cwhdn, slayout, cta_t_tile, cta_v_tile); +} + +// Explicit default for multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, cta_tiler, Int<1>{}); +} + +// Explicit default for cta_tiler and multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, product_each(shape(slayout)), Int<1>{}); +} + +} // namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm90_tma.hpp b/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm90_tma.hpp new file mode 100644 index 00000000..8dba3ded --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm90_tma.hpp @@ -0,0 +1,1326 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include +#include +#include + +#include + +#include + +namespace cute +{ + +template +struct AuxTmaParams { + using GmemStrides = GmemTmaBasisStrides_; // Strides for Gmem mode -> Tma coord mode, may be dynamic + GmemStrides g_stride_; + using TmaGmemBasis = TmaGmemBasis_; // Layout for Tma box shape -> Gmem mode(s), always static + static_assert(is_static::value); + using TmaSwizzle = TmaSwizzle_; // Tma swizzle, always Swizzle + static_assert(is_static::value); +}; + +// Utility for unpacking TMA_LOAD arguments into a CopyOp +template +struct TMA_LOAD_Unpack +{ + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + auto src_coord = src.data().coord_; + if constexpr (detail::is_prefetch) { + return detail::copy_explode(traits.opargs_, tuple_seq{}, + src_coord, tuple_seq{}); + } else { + static_assert(is_smem::value, "SM90_TMA_LOAD requires the destination be shared memory."); + void* dst_ptr = cute::raw_pointer_cast(dst.data()); +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); +#endif + return detail::copy_explode(traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord, tuple_seq{}); + } + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD /////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_OP : SM90_TMA_LOAD {}; + +// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar +// Use .with(tma_mbar) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& 
multicast_mask = 0) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {&tma_desc_, &tma_mbar}}; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {new_tma_desc, &tma_mbar}}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM90_TMA_LOAD before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM90_TMA_LOAD with tma_desc and tma_mbar +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + tuple< + TmaDescriptor const*, + uint64_t* // smem mbarrier + > const opargs_; +}; + +// The prefetch for SM90_TMA_LOAD with tma_desc +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD::PREFETCH arguments + tuple const opargs_; + + // Construct with any other Traits' TMA Desc + template + CUTE_HOST_DEVICE + Copy_Traits(Copy_Traits const& traits) + : opargs_({&traits.tma_desc_}) {} +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_MULTICAST_OP : SM90_TMA_LOAD_MULTICAST {}; + +// The non-executable SM90_TMA_LOAD_MULTICAST with tma_desc and no tma_mbar +// Use .with(tma_mbar, multicast_mask) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { + return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask}}; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. 
overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { + return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask}}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t // multicast mask + > const opargs_; +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_STORE ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +// The executable SM90_TMA_STORE with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor + + void const* const desc_ptr = &(traits.tma_desc_); + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = dst.data().coord_; +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + return detail::copy_explode(make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_REDUCE_ADD ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +// The executable SM90_TMA_REDUCE_ADD with tma_desc +template +struct 
Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_REDUCE_ADD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + template + CUTE_HOST_DEVICE constexpr + void + copy_unpack_(void const* const src_ptr, + Coord const& dst_coord, seq) const + { +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + + SM90_TMA_REDUCE_ADD::copy(&tma_desc_, + src_ptr, get(dst_coord)...); + } + + // This is the copy_unpack dispatch for this Copy_Traits + // Src needs to be a smem tensor + // Dst needs to be a gmem tensor with TmaCoordIterator .data() + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_REDUCE_ADD"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_REDUCE_ADD"); // TMA spoofed src tensor + + traits.copy_unpack_(cute::raw_pointer_cast(src.data()), dst.data().coord_, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// BULK COPY ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +template +struct Copy_Traits +{ + static_assert(int32_t(NumBitsPerTMA::value / 8) % 16 == 0, + "Bulk Copy requires copy vector size align to 16B."); + + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_BULK_COPY_G2S arguments + // 0: uint64_t* bulk_load_memory_barrier + cute::tuple bulk_load_mbar_; + + // Record the memory barrier for the instruction + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& bulk_mbar) const { + return {{&bulk_mbar}}; + } + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_same, cute::tuple>::value, + "Extra arguments not set. 
Set .with() before use."); + static_assert(is_gmem::value, "Expected gmem src for SM90_BULK_COPY_G2S"); + static_assert(is_smem::value, "Expected smem dst for SM90_BULK_COPY_G2S"); + SM90_BULK_COPY_G2S::copy(raw_pointer_cast(src.data()), get<0>(traits.bulk_load_mbar_), + raw_pointer_cast(dst.data()), int32_t(NumBitsPerTMA::value / 8)); + } +}; + +template +struct Copy_Traits + : Copy_Traits +{ + template + CUTE_HOST_DEVICE + Copy_Traits(Copy_Traits const& traits) {} + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem src for SM90_BULK_PREFETCH"); + SM90_BULK_COPY_G2S::PREFETCH::copy(raw_pointer_cast(src.data()), int32_t(NumBitsPerTMA::value / 8)); + } +}; + +template +struct Copy_Traits +{ + static_assert(int32_t(NumBitsPerTMA::value / 8) % 16 == 0, + "Bulk Copy requires copy vector size align to 16B."); + + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_BULK_COPY_S2G"); + static_assert(is_gmem::value, "Expected gmem dst for SM90_BULK_COPY_S2G"); + SM90_BULK_COPY_S2G::copy(raw_pointer_cast(src.data()), raw_pointer_cast(dst.data()), int32_t(NumBitsPerTMA::value / 8)); + } +}; + +// +// Placeholder for the bulk copy algorithm's default, auto-vectorizing behavior +// + +template +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0,_0>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride<_0,_0>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_UBULK_COPY arguments + // 0: uint64_t* bulk_load_memory_barrier [if this is a BULK_LOAD_G2S] + cute::tuple opargs_; + + // Record the memory barrier for the instruction + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& bulk_mbar) const { + return {{&bulk_mbar}}; + } +}; + +// +// MAKE_TMA_COPY and related +// + +namespace detail { + +// Custom version of coalesce that greedily combines modes only up to size-256 +// Look at each element and the back of the stack (in order of priority) +// back(NewLayout) get(OldLayout) +// s0:d0 _1:d1 => continue +// _1:d0 s1:d1 => replace_back s1:d1 +// s0:d0 s1:s0*d0 => replace_back s0*s1:d0 if s0*s1 <= 256 +// s0:d0 s1:d1 => append s1:d1 +// +// @pre OldShape and OldStride are flat +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_256_impl(OldShape const& old_shape, OldStride const& old_stride, + NewShape const& new_shape, NewStride const& new_stride) +{ + if constexpr (I == rank_v) { + // Base case, we're done + if constexpr (is_constant<1, NewShape>::value) { + return Layout<_1,_0>{}; + } else { + return Layout{new_shape,new_stride}; + } + } else if constexpr (is_constant<1, decltype(get(old_shape))>::value) { + // shape(layout) == _1, skip it and continue + return coalesce_256_impl(old_shape, old_stride, new_shape, new_stride); + } else if constexpr (is_constant<1, NewShape>::value) { + // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) + return coalesce_256_impl(old_shape, 
old_stride, get(old_shape), get(old_stride)); + } else if constexpr (is_constant(old_stride) && + get(old_shape) * back(new_shape) <= Int<256>{})>::value) { + // Merge modes because the shapes and strides match and the merge is 256 or less + return coalesce_256_impl(old_shape, old_stride, + replace_back(new_shape, get(old_shape) * back(new_shape)), + new_stride); + } else { + // Can't replace or merge, so append a new mode + return coalesce_256_impl(old_shape, old_stride, + append(new_shape, get(old_shape)), + append(new_stride, get(old_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +// Combine all the modes that are possible to combine +// Does not respect the profile of the layout, but does preserve total size +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_256(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + return coalesce_256_impl<1>(flat_shape, flat_stride, get<0>(flat_shape), get<0>(flat_stride)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +construct_tma_gbasis(Tensor const& gtensor, // The original GMEM Tensor + Layout const& slayout, // The layout of SMEM + Layout const& cta_v_map) // smem_idx to hier gmode +{ + // + // TMA parameter checking + // + + CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + +#if 0 + print("gtensor : "); print(gtensor); print("\n"); + print("slayout : "); print(slayout); print("\n"); + print("cta_v_map : "); print(cta_v_map); print("\n"); +#endif + + // + // TMA slayout manipulation + // + + // Invert the smem to get the largest contiguous vector in the smem layout + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + + // Compose with the V-Map to convert smem coord (CTA val idx) to gmem mode + // smem idx -> gmem mode + auto sidx2gmode_full = coalesce(composition(cta_v_map, inv_smem_layout)); + +#if 0 + print("inv_smem_layout : "); print(inv_smem_layout); print("\n"); + print("sidx2gmode_full : "); print(sidx2gmode_full); print("\n"); +#endif + + // + // TMA gtensor truncation + // + + // Truncate any incompatibilities -- no starting in the middle of gmodes + auto smem_rank = find_if(stride(sidx2gmode_full), [](auto e) { + [[maybe_unused]] auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; + }); + static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. 
Does the Tile select out major GMEM modes?"); + + // Keep only the static-1 basis modes into gmem + auto sidx2gmode = take<0,smem_rank>(sidx2gmode_full); + +#if 0 + print("smem_rank : "); print(smem_rank); print("\n"); + print("sidx2gmode : "); print(sidx2gmode); print("\n"); +#endif + + // + // TMA gtensor manipulation + // + + // The smem vector is the same units as gtensor, so compose first and then recast + // tma_val_idx:gmem_strides + auto tile_gstride = recast(gtensor.compose(sidx2gmode)).layout(); + // Coalesce modes up to size-256 (the maximum TMA box extent in units of TmaInternalType) + // tma_box_shape:gmem_strides + auto tma_gstride = coalesce_256(tile_gstride); + + // Perform the tiling, recast, and coalesce to the gmem vector again, but with indirections to the gtensor modes + auto gbasis = make_identity_layout(shape(gtensor)); + auto tile_gbasis_tmp = gbasis.compose(sidx2gmode); + + // Instead of the recast (gbasis doesn't have type info), replace the shape with the already-recasted shape + // tma_box_shape:gmem_mode + auto tile_gbasis = make_layout(shape(tile_gstride), stride(tile_gbasis_tmp)); + + // "Coalesce" the tile basis into a compatible shape with the tma_gstride + auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride)))); + + // Recast the original tensor for shape/stride inspections + Tensor gtensor_T = recast(gtensor); + + // Find missing bases that don't appear in tile_gbasis + auto tile_gbasis_remaining_stride = filter_tuple(flatten(shape (gtensor_T)), flatten(stride(gtensor_T)), + flatten(stride(gbasis)), + [&](auto s, auto d, auto e) + { + if constexpr (is_constant<1, decltype(s)>::value || is_constant<0, decltype(d)>::value) { + return cute::tuple<>{}; // If size-1 or stride-0, then don't append + } else { + using E = decltype(e); + auto has_e = any_of(flatten(stride(tma_gbasis_tile)), [] (auto tb) { return tb == E{}; }); + if constexpr (decltype(has_e)::value) { + return cute::tuple<>{}; // If d was found, then don't append + } else { + return cute::tuple(e); // Else, this is missing so append + } + } + }); + + // Append the remaining basis modes that contribute to the TMA with size-1 + auto tile_gbasis_remaining_shape = repeat(Int<1>{}); + auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(tile_gbasis_remaining_shape )), + tuple_cat(wrap(stride(tma_gbasis_tile)), wrap(tile_gbasis_remaining_stride))); + + // Group the trailing modes to make this max rank-5 -- TMA rank limitation + // tma_box_shape:gmem_mode + auto tma_gbasis = group(tma_gbasis_full); + +#if 0 + print("tile_gstride : "); print(tile_gstride); print("\n"); + print("tma_gstride : "); print(tma_gstride); print("\n"); + print("gbasis : "); print(gbasis); print("\n"); + print("tile_gbasis : "); print(tma_gbasis_tile); print("\n"); + print("tma_gbasis : "); print(tma_gbasis); print("\n"); +#endif + + return tma_gbasis; +} + +template +CUTE_HOST_DEVICE constexpr +void +fill_tma_gmem_shape_stride(Tensor const& gtensor, // Gmem Shapes and Strides, in units of TmaInternalType + TmaGmemBasisStride const& tma_gbasis_stride, // Map Tma mode idx -> Gmem mode(s) + cute::array & gmem_prob_shape, // Tma Shapes, uint32_t or uin64_t + cute::array & gmem_prob_stride) // Tma Strides +{ + static_assert(is_tuple::value); + static_assert(is_same::value || is_same::value); + + using TmaInternalType = typename GEngine::value_type; + constexpr int tma_rank = decltype(rank(tma_gbasis_stride))::value; + static_assert(TmaRank >= tma_rank); + + auto gmem_shape = 
shape(gtensor); + auto gmem_stride = stride(gtensor); + // Use the indirections in tma_gbasis_stride into gtensor to construct the tma gmem shapes/strides + for_each(make_seq{}, [&](auto i) { + constexpr int tma_i_rank = decltype(rank(tma_gbasis_stride))::value; + if constexpr (tma_i_rank == 1) { + // Trivial contribution of this gmem mode to this tma mode + auto ej = unwrap(get(tma_gbasis_stride)); + gmem_prob_shape[i] = basis_get(ej, gmem_shape); + gmem_prob_stride[i] = basis_get(ej, gmem_stride); + } else { + // Apply a recurrence to each gmem mode that contributes to this tma mode + for_each(get(tma_gbasis_stride), [&](auto ej) { + // Problem shape + uint64_t shape_j = basis_get(ej, gmem_shape); + // Problem stride (in bytes) + uint64_t stride_j = basis_get(ej, gmem_stride); + uint64_t old_stride = gmem_prob_stride[i]; + gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j); + + if (gmem_prob_stride[i] != 0) { + // Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1 + gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i]) + + (shape_j-1) * (stride_j / gmem_prob_stride[i]) + + 1; + } else { + gmem_prob_shape[i] = shape_j; + } + }); + } + }); +} + +// Overload for an existing Copy_Traits +template +CUTE_HOST_DEVICE constexpr +void +fill_tma_gmem_shape_stride(Copy_Traits const& tma_traits, + Tensor const& gtensor, // Gmem Shapes and Strides, value_type = TmaInternalType + cute::array & gmem_prob_shape, // Tma Shapes, uint32_t or uin64_t + cute::array & gmem_prob_stride) // Tma Strides +{ + return fill_tma_gmem_shape_stride(gtensor, stride(typename Aux::TmaGmemBasis{}), + gmem_prob_shape, gmem_prob_stride); +} + +// Use a sidx2gmode to read through the GMEM tensor +// and construct a TMA Descriptor for the resulting instruction +// At the same time, construct the Tma Tensor's Stride to generate +// the TMA coordinates that the instruction consumes. 
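+// Informal summary of the steps below: (1) recast the tensor to the TMA
+// internal type and fill the rank-5 gmem_prob_shape / gmem_prob_stride arrays
+// (strides converted to bytes, alignment and range asserted), (2) derive the
+// shared-memory box from the mode sizes of tma_gbasis and shrink it across the
+// multicast CTAs, (3) encode the CUtensorMap via cuTensorMapEncodeTiled
+// (guarded for CUDA 12+ host compilation), and (4) build the basis strides
+// returned in AuxTmaParams, which get_tma_tensor later uses to produce the TMA
+// coordinates consumed by the instruction.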
+// +template +CUTE_HOST_RTC +auto +make_tma_copy_desc(Tensor const& gtensor, // The original GMEM Tensor + Layout const& tma_gbasis, // TMA mode -> GMEM mode mapping + Swizzle const& swizzle, // Swizzle fn on smem_idx + uint32_t num_multicast) // The number of CTAs in multicasting +{ + // + // TMA desc creation + // + + constexpr int tma_dim = decltype(rank(tma_gbasis))::value; + + // + // TMA gmem desc info + // + + // Recast the original tensor for shape/stride inspections + Tensor gtensor_T = recast(gtensor); + + void* gmem_address = (void*) raw_pointer_cast(gtensor_T.data()); + auto gmem_layout = gtensor_T.layout(); + + cute::array gmem_prob_shape = {1,1,1,1,1}; + cute::array gmem_prob_stride = {0,0,0,0,0}; + + fill_tma_gmem_shape_stride(gtensor_T, stride(tma_gbasis), gmem_prob_shape, gmem_prob_stride); + + assert((reinterpret_cast(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned + + assert(gmem_prob_shape[0] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[0] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[1] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[1] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[2] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[2] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[3] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[3] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[4] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32 + + // TMA descriptor does not store the zeroth stride and assumes it is 1 (TmaInternalType element). + assert(gmem_prob_stride[0] == 1 && "Majorness of smem doesn't match majorness of gmem"); + + // convert strides to byte strides + for(uint64_t& stride : gmem_prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + + // Assert the byte strides. 
Tma Descriptor uses byte strides + assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[2]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[2] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[3]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[3] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[4]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[4] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + + // + // TMA smem desc info + // + + cute::array smem_box_shape = {1,1,1,1,1}; + cute::array smem_box_stride = {1,1,1,1,1}; + // The smem box is simply given by the sizes of the modes in tma_gbasis + for_each(make_seq{}, [&](auto i) { + smem_box_shape[i] *= size(tma_gbasis); + }); + // Finally, truncate the tma box by the num_multicast + for (uint32_t i = tma_dim-1, multicast = num_multicast; multicast > 1; --i) { + assert(smem_box_shape[i] % multicast == 0 || multicast % smem_box_shape[i] == 0); + uint32_t new_mult = ceil_div(multicast, smem_box_shape[i]); + smem_box_shape[i] = ceil_div(smem_box_shape[i], multicast); + multicast = new_mult; + } + + assert(smem_box_shape[0] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[0] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[1] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[1] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[2] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[2] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[3] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[3] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[4] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[4] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + + assert(smem_box_stride[0] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[0] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[1] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[1] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[2] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[2] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[3] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[3] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[4] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[4] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + + // + // Construct the descriptor + // + + TmaDescriptor tma_desc{}; + + // + // TMA general info + // + + #if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + + CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); + CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + // TMA smem swizzle type + CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle)); + CUresult result = cuTensorMapEncodeTiled( + &tma_desc, + tma_format, + tma_dim, + gmem_address, + 
gmem_prob_shape.data(), + gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly 1 + smem_box_shape.data(), + smem_box_stride.data(), + tma_interleave, + smem_swizzle, + tma_l2Promotion, + tma_oobFill); + + if (result != CUDA_SUCCESS) { + std::cerr << "TMA Desc Addr: " << &tma_desc + << "\nformat " << tma_format + << "\ndim " << tma_dim + << "\ngmem_address " << gmem_address + << "\nglobalDim " << gmem_prob_shape + << "\nglobalStrides " << gmem_prob_stride + << "\nboxDim " << smem_box_shape + << "\nelementStrides " << smem_box_stride + << "\ninterleave " << tma_interleave + << "\nswizzle " << smem_swizzle + << "\nl2Promotion " << tma_l2Promotion + << "\noobFill " << tma_oobFill << std::endl; + std::cerr << "Error: Failed to initialize the TMA descriptor " << result << std::endl; + assert(false); + } + + #endif // (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + auto recast_ratio = cute::trait_ratio(sizeof_bits{}, + sizeof_bits< TmaInternalType>{}); + + auto gbasis = make_basis_like(shape(gtensor)); + + // Finally, get the inverse permutation of the E bases for the mocked gmem stride + auto gmem_tma_basis_stride = transform_leaf(gbasis, [&](auto ei) { + auto si = basis_get(ei, shape(gmem_layout)); + auto di = basis_get(ei, stride(gmem_layout)); + if constexpr (is_constant<1, decltype(si)>::value || is_constant<0, decltype(di)>::value) { + return Int<0>{}; // If size-1 or stride-0, return arithmetic identity -- no contribution to the TMA + } else { + auto tma_gmem_basis_stride = stride(tma_gbasis); + // Find j such that E is in stride(tma_gbasis) + using EI = decltype(ei); + [[maybe_unused]] auto j = find_if(tma_gmem_basis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); }); + if constexpr (decltype(j == rank(tma_gmem_basis_stride))::value) { + return Int<0>{}; // If not-found, return arithmetic identity -- no contribution to the TMA + } else + if constexpr (decltype(j == Int<0>{})::value) { + auto scale = recast_ratio * basis_get(ei, stride(gtensor)); + return E{} * scale; // Return TMA Coord basis -- with a recast scale factor + } else + if constexpr (decltype(rank(tma_gmem_basis_stride) == Int<1>{})::value) { + return E{}; // Return TMA Coord basis -- known scale of Int<1>{} + } else { + int32_t scale = ceil_div(int32_t(di * sizeof_bits_v / cute::max(gmem_prob_stride[j], uint64_t{16})), 8); + return E{} * scale; // Return TMA Coord basis -- with a dynamic scale factor + } + } + }); + +#if 0 + print("gmem_tma_basis_stride : "); print(gmem_tma_basis_stride); print("\n"); +#endif + + using AuxParams = AuxTmaParams; + return cute::make_tuple(tma_desc, AuxParams{gmem_tma_basis_stride}); +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_atom(CopyOp, + Tensor const& gtensor, // Full GMEM Tensor + SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled + uint32_t const& num_multicast, // The number of CTAs involved in multicasting + Layout const& cta_v_map) // V: CTA val idx -> gmem mode +{ + // + // TMA truncated layout + // + + auto smem_swizzle = get_swizzle_portion(slayout); + auto smem_layout = get_nonswizzle_portion(slayout); + + auto tma_gbasis = detail::construct_tma_gbasis(gtensor, smem_layout, cta_v_map); + + // + // Construct the TMA Desc and the strides of the TMA Tensor + // + + auto [tma_desc, aux_params] = detail::make_tma_copy_desc(gtensor, + tma_gbasis, + smem_swizzle, + num_multicast); + + // + // Construct the Copy_Traits + // + + constexpr int num_bits_per_tma = size(tma_gbasis) * sizeof_bits_v; + 
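+  // num_bits_per_tma is the payload of one TMA instruction: the number of
+  // elements in the TMA box (size of tma_gbasis) times the width of the TMA
+  // element type in bits. It becomes the NumBitsPerTMA parameter of the
+  // Copy_Traits below, i.e. the value extent (in bits) that the atom's
+  // SrcLayout/DstLayout describe for the single logical thread.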
using Traits = Copy_Traits, decltype(aux_params)>; + using Atom = Copy_Atom; + + Traits tma_traits{tma_desc, aux_params}; + +#if 0 + print("num_bits_per_tma : "); print(num_bits_per_tma); print("\n"); + print("g_stride_bases : "); print(tma_traits.aux_params_.g_stride_); print("\n"); +#endif + + // Return the Copy_Atom + return Atom{tma_traits}; +} + +// The "logical TMA tid" is a map from the CTA rank to its logical id +// within the instruction. It works like a mask or ordering on the +// CTAs. For non-multicast TMA, all CTAs should map to 0. For +// multicast TMA of size 4, CTAs will be mapped to {0,1,2,3}. +template +CUTE_HOST_RTC +auto +make_tma_copy_tiled(CopyOp const& copy_op, + Tensor const& gtensor, // Full GMEM Tensor + SLayout const& slayout, // CTA Tile of SMEM + Layout const& cta_t_map, // T: CTA thr idx -> logical TMA tid + Layout const& cta_v_map) // V: CTA val idx -> gmem mode +{ + Copy_Atom atom = make_tma_copy_atom(copy_op, gtensor, slayout, + cosize(cta_t_map), cta_v_map); + + // + // Construct the TiledCopy + // + + [[maybe_unused]] auto cta_tiler = product_each(shape(cta_v_map)); + + auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / static_value>(); + + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // CTA V -> smem_coord + auto layout_v = composition(inv_smem_layout, num_elems_per_tma); + // Scale that up to cover all of the smem_coords + auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); + // CTA T -> smem idx + auto layout_t = make_layout(cosize(cta_t_map), shape_div(num_elems_per_tma, cosize(cta_t_map))); + // CTA TID -> smem coord + auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); + // Combine with the T mapping + [[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V); + +#if 0 + print("cta_tiler : "); print(cta_tiler); print("\n"); + print("layout_v : "); print(layout_v); print("\n"); + print("layout_V : "); print(layout_V); print("\n"); + print("layout_t : "); print(layout_t); print("\n"); + print("layout_T : "); print(layout_T); print("\n"); + print("layout_TV : "); print(layout_TV); print("\n"); +#endif + + return TiledCopy{atom}; +} + +} // end namespace detail + +/** Make a CuTe CTA-collective TiledCopy for a TMA operation. + * + * @param CopyOp The target copy operation: SM90_TMA_LOAD, SM90_TMA_LOAD_MULTICAST, SM90_TMA_STORE + * @param gtensor The GMEM Tensor to be involved in the TMA. + * @param slayout The SMEM Layout to be involved in the TMA. + * @param cta_tile The CTA-local tile that each CTA will be tiling GMEM with. + * This is often the blk_shape that is used to tile the GMEM for CTAs: + * local_tile(gtensor, blk_shape, blk_coord) -> CTA-local tile of gtensor + * @param cluster_size When using SM90_TMA_LOAD_MULTICAST, this can be a (static) power-of-2 <= 16 + * defining the multicast size (used to further partition the SMEM) + * Else, static-1 + * + * This code attempts to maximize the TMA box size. It does this by tracing + * the SMEM "vector" -- the inverse of the smem layout -- to find the largest + * contiguous array of smem that can be written to/from global memory given + * the constraints that the TMA instruction imposes. + * + * This is accomplished by assigning "basis" strides to the GMEM to track which + * modes of SMEM map to which modes of GMEM, then reorder the modes of GMEM according + * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. 
+ * + * Examples: + using T = float; + T* gptr = nullptr; + + { + // Simple 2D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 256), GenRowMajor{}); // K-Major GMEM + auto slayout = make_layout(make_shape(_64{}, _32{}), GenRowMajor{}); // K-Major SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // GMMA 2D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 256)); // MN-Major GMEM + auto slayout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(_128{},_64{})); // MN-Major Swizzled+Tiled 128x64 SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // 3D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 32, 512), make_stride(64, Int<1>{}, 65536)); // GMEM + auto slayout = make_layout(make_shape(_16{}, _8{}, _2{}), make_stride(_16{}, _1{}, _8{})); // SMEM w/ same major-mode + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // cuTENSOR 4D + auto layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM + auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: + // Take 128-elem from m: m0 must divide 128, + // m-last may be predicated + // Take 32-elem from k0, 2-elem from k1 + auto slayout = make_layout(cta_tile); // Col-Major SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout, cta_tile, Int<1>{}); + } + * + * Check the TMA box size and desc: + print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + print("TMA desc : "); print(tma.tma_desc_); print("\n"); + * + * Usage: + Tensor mA = tma_a.get_tma_tensor(make_shape(M,N)); // (M,N) TMA coord tensor + Tensor gA = local_tile(mA, cta_tile, cta_coord); // (BLK_M,BLK_N) TMA coord tensor for this CTA + Tensor sA = make_tensor(make_smem_ptr(sptr), slayout); // (BLK_M,BLK_N) SMEM tensor + + auto cta_tma = tma.get_slice(cta_idx_in_cluster); // Slice for multicast partitioning + Tensor tAgA = cta_tma.partition_S(gA); // Partition for src + Tensor tAsA = cta_tma.partition_D(sA); // Partition for dst + + copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params + */ +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler, + cluster_size); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler); + auto cta_t_tile = make_layout(cluster_size); + // Prefer TmaInternalType if specified. 
Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, + gtensor, slayout, + cta_t_tile, cta_v_tile); + } +} + +// Explicit defaulting +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout) +{ + return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), Int<1>{}); +} + +// Explicit defaulting +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Cluster_Size const& cluster_size) +{ + return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size); +} + +//////////////////////////////////// +// Experimental Make TMA Atom and Partitioner +/////////////////////////////////// + +template +CUTE_HOST_RTC +auto +make_tma_atom(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler); + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_atom(copy_op, + gtensor, slayout, + size(cluster_size), cta_v_tile); +} + +// The "VectorCopy Partitioner" for TMA +template +CUTE_DEVICE +auto +tma_partition(Copy_Atom const& copy_atom, + CtaCoord const& cta_coord, + Layout const& cta_layout, // T: CTA coord -> logical multicast id + Tensor const& stensor, // SMEM Tensor (TMATile, Rest...) + Tensor const& gtensor) // GMEM Tensor (TMATile, Rest...) +{ + CUTE_STATIC_ASSERT_V(size<0>(stensor) == size<0>(gtensor)); + + // Invert the smem to get the largest contiguous vector in the smem layout + Layout inv_smem_layout = right_inverse(get_nonswizzle_portion(layout<0>(stensor))); + // Scale that up to cover all of the smem_coords + Layout layout_v = tile_to_shape(make_layout(inv_smem_layout), size<0>(stensor)); + + // Factor out the single-instrucion portion + Layout tma_layout_v = make_layout(Int::NumValSrc>{}); + auto layout_V = make_tile(logical_divide(layout_v, tma_layout_v)); + + // Append with _ until we cover all Rest... modes + auto glayout_V = append>(layout_V, _); + auto slayout_V = append>(layout_V, _); + // Transform tile mode and coalesce + Tensor gtensor_v = coalesce(gtensor.compose(glayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) + Tensor stensor_v = coalesce(stensor.compose(slayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) + +#if 0 + if (thread0()) { + print("gtensor : "); print(gtensor); print("\n"); + print("stensor : "); print(stensor); print("\n"); + print("layout_V : "); print(layout_V); print("\n"); + print("gtensor_v : "); print(gtensor_v); print("\n"); + print("stensor_v : "); print(stensor_v); print("\n"); + } +#endif + + // Restride the cta-into-tma-instr layout + Layout tma_layout_t = composition(make_layout(Int<1>{}, shape_div(size(tma_layout_v), cosize(cta_layout))), cta_layout); + auto tma_layout_tv = make_tile(make_tile(make_layout(tma_layout_t, tma_layout_v), _)); + + // Append with _ until we cover all Rest... modes + auto gtma_layout_tv = append>(tma_layout_tv, _); + auto stma_layout_tv = append>(tma_layout_tv, _); + + // Transform TMA mode + Tensor gtensor_tv = gtensor_v.compose(gtma_layout_tv); // (((Thr,Frg),TMA_Iter), Rest...) 
+ Tensor stensor_tv = stensor_v.compose(stma_layout_tv); // (((Thr,Frg),TMA_Iter), Rest...) + +#if 0 + if (thread0()) { + print("tma_layout_tv : "); print(tma_layout_tv); print("\n"); + print("gtensor_tv : "); print(gtensor_tv); print("\n"); + print("stensor_tv : "); print(stensor_tv); print("\n"); + } +#endif + + auto c = make_coord(make_coord(make_coord(cta_coord, _), _)); + auto c_s = append>(c, _); + auto c_g = append>(c, _); + + return cute::make_tuple(group_modes<0,2>(gtensor_tv(c_g)), group_modes<0,2>(stensor_tv(c_s))); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm90_tma_swizzle.hpp b/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm90_tma_swizzle.hpp new file mode 100644 index 00000000..bb44a835 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/copy_traits_sm90_tma_swizzle.hpp @@ -0,0 +1,71 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/// @file copy_traits_sm90_tma_swizzle.hpp +/// @brief Functions for converting swizzle layout to TMA descriptor + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include +#include + +namespace cute::detail { + +template +CUTE_HOST_DEVICE constexpr +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Swizzle) +{ + if constexpr (M == 4) { + switch (B) { + default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. 
Unsupported layout swizzle."); + case 3: return TMA::SmemSwizzleBits::B128; + case 2: return TMA::SmemSwizzleBits::B64; + case 1: return TMA::SmemSwizzleBits::B32; + case 0: return TMA::SmemSwizzleBits::DISABLE; + } + } else + { + static_assert(M < 0, "Unsupported layout swizzle."); + } +} + +template +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Layout const& layout) +{ + return get_tma_swizzle_bits(get_swizzle_portion(layout)); +} + +} // namespace cute::detail diff --git a/server/punica_kernels/include/cutlass/cute/atom/mma_atom.hpp b/server/punica_kernels/include/cutlass/cute/atom/mma_atom.hpp new file mode 100644 index 00000000..674e3519 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/mma_atom.hpp @@ -0,0 +1,949 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include +#include +#include + +namespace cute { + +template +struct MMA_Atom; + +template +struct MMA_Atom : MMA_Atom> +{}; + +template +struct MMA_Atom> + : MMA_Traits +{ + using Traits = MMA_Traits; + + // Element value types from the MMA_Traits + using ValTypeD = typename Traits::ValTypeD; + using ValTypeA = typename Traits::ValTypeA; + using ValTypeB = typename Traits::ValTypeB; + using ValTypeC = typename Traits::ValTypeC; + + // Thr-Val layouts from the MMA_Traits + using Shape_MNK = typename Traits::Shape_MNK; + using ThrID = typename Traits::ThrID; + using LayoutC_TV = typename Traits::CLayout; + using LayoutA_TV = typename Traits::ALayout; + using LayoutB_TV = typename Traits::BLayout; + + // Fragment value types from the MMA_Traits (optional, defaults to Val type) + using FrgTypeD = typename detail::FrgTypeC_or_Default::type; + using FrgTypeA = typename detail::FrgTypeA_or_Default::type; + using FrgTypeB = typename detail::FrgTypeB_or_Default::type; + using FrgTypeC = typename detail::FrgTypeC_or_Default::type; + + // Additional Trait parameters/transformations + template + CUTE_HOST_DEVICE + auto + with(TraitsArgs&&... args) const { + auto traits = Traits::with(static_cast(args)...); + return MMA_Atom{traits}; + } + + // + // Tensor call interfaces + // + + // Cast, check, and call fma + template + CUTE_HOST_DEVICE constexpr + void + call(Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) const + { + static_assert(DLayout::rank == 1, "Expected rank-1 D tensor"); + static_assert(ALayout::rank == 1, "Expected rank-1 A tensor"); + static_assert(BLayout::rank == 1, "Expected rank-1 B tensor"); + static_assert(CLayout::rank == 1, "Expected rank-1 C tensor"); + + return mma_unpack(*this, D, A, B, C); + } + + // Three arguments reproduces C + template + CUTE_HOST_DEVICE constexpr + void + call(Tensor const& A, + Tensor const& B, + Tensor & C) const + { + return call(C, A, B, C); + } + + // + // make_fragment_A|B|C + // These functions are awkward as they expect already-partitioned tensors + // resulting from a previous call to partition_A|B|C + // The reasoning is that we can inspect the layout of the partitioned data + // and attempt to match it in generated fragment to promote vectorization + // when copying from partition to fragment. 
+ // + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_C(CTensor&& ctensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<3>{}); // VMN + CUTE_STATIC_ASSERT_V(size<0>(ctensor) == size<1>(LayoutC_TV{})); + // C is a bit special because we are after accumulators here + // The input/output type doesn't have to match the accumulator type + //static_assert(std::is_same::value_type>::value, "Expecting ValTypeC type"); + + // We'll never base the accumulator layout on the input tensor layout, so just return a FrgTypeC tensor + return make_tensor(shape(ctensor)); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_A(ATensor&& atensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<3>{}); // VMK + CUTE_STATIC_ASSERT_V(size<0>(atensor) == size<1>(LayoutA_TV{})); + + if constexpr (has_dereference::value) { + // If the intended FrgTypeA is a view (of the current tensor), forward the whole + static_assert(is_same::value_type>::value + , "Expecting ValTypeA type"); + return make_tensor(static_cast(atensor)); + } else { + // Else, the intended FrgTypeA is a value type, construct a new tensor with a fragment layout + return make_fragment_like(atensor); + } + + CUTE_GCC_UNREACHABLE; + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_B(BTensor&& btensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<3>{}); // VNK + CUTE_STATIC_ASSERT_V(size<0>(btensor) == size<1>(LayoutB_TV{})); + + if constexpr (has_dereference::value) { + // If the intended FrgTypeB is a view (of the current tensor), forward the whole + static_assert(is_same::value_type>::value + , "Expecting ValTypeB type"); + return make_tensor(static_cast(btensor)); + } else { + // Else, the intended FrgTypeB is a value type, construct a new tensor with a fragment layout + return make_fragment_like(btensor); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +// +// A tiling of mma atoms +// + +template +struct ThrMMA; + +// @tparam MMA_Atom The MMA_Atom to use in the TiledMMA +// @tparam AtomLayoutMNK The MNK-tiling of the Atom to be performed. +// @tparam PermuationsMNK Permutations to apply to each MNK-mode before tiling for the Atom. +template > +struct TiledMMA : MMA_Atom +{ + using Atom = MMA_Atom; + using AtomShape_MNK = typename MMA_Atom::Shape_MNK; + using AtomThrID = typename MMA_Atom::ThrID; + using AtomLayoutC_TV = typename MMA_Atom::LayoutC_TV; + using AtomLayoutA_TV = typename MMA_Atom::LayoutA_TV; + using AtomLayoutB_TV = typename MMA_Atom::LayoutB_TV; + + static_assert( rank_v == 3, "TiledMMA requires rank-3 AtomLayoutMNK"); + static_assert( rank_v == 3, "TiledMMA requires rank-3 PermutationMNK"); + static_assert( is_tuple::value, "TiledMMA requires independent permutations of MNK."); + static_assert(is_static::value, "TiledMMA requires static permutations of MNK."); + + using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{})); + ThrLayoutVMNK thr_layout_vmnk_; + + CUTE_HOST_DEVICE constexpr + TiledMMA(MMA_Atom const& mma_atom = {}, AtomLayoutMNK const& thr_layout_mnk = {}) + : MMA_Atom(mma_atom), + thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)) {} + + CUTE_HOST_DEVICE constexpr auto + get_thr_layout_vmnk() const { + return thr_layout_vmnk_; + } + + // Tile a tensor or a layout from shape + // (M,N,...) 
+ // to shape + // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx + // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx + // FrgV: The values local to an MMA. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_C(CTensor&& ctensor) const + { + CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<2>{}); + //CUTE_STATIC_ASSERT_V(size<0>(ctensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); + //CUTE_STATIC_ASSERT_V(size<1>(ctensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(PermutationMNK{}), + get<1>(PermutationMNK{})); + auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<1>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomN),(RestM,RestN)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) + + // Tile the tensor for the C-threads + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk_)), + make_layout(size<2>(thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN))) + + return thr_tensor; + } + + // Tile a tensor or a layout from shape + // (M,K,...) + // to shape + // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx + // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx + // FrgV: The values local to an MMA. + // RestM: The values tiled in M. + // RestK: The values tiled in K. + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_A(ATensor&& atensor) const + { + CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<2>{}); + //CUTE_STATIC_ASSERT_V(size<0>(atensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); + //CUTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(PermutationMNK{}), + get<2>(PermutationMNK{})); + auto t_tensor = logical_divide(atensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk_)), + make_layout(size<3>(thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + // Tile a tensor or a layout from shape + // (N,K,...) + // to shape + // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrN: The threads tiled in N. 
layout<2>(ThrLayoutVMNK): ThrN -> thread_idx + // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx + // FrgV: The values local to an MMA. + // RestN: The values tiled in N. + // RestK: The values tiled in K. + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_B(BTensor&& btensor) const + { + CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<2>{}); + //CUTE_STATIC_ASSERT_V(size<0>(btensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); + //CUTE_STATIC_ASSERT_V(size<1>(btensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<1>(PermutationMNK{}), + get<2>(PermutationMNK{})); + auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<2>(thr_layout_vmnk_)), + make_layout(size<3>(thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) + + return thr_tensor; + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_slice(ThrIdx const& thr_idx) const + { + auto thr_vmnk = thr_layout_vmnk_.get_flat_coord(thr_idx); + return ThrMMA{*this, thr_vmnk}; + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_thread_slice(ThrIdx const& thr_idx) const + { + return get_slice(thr_idx); + } + + // + // Utility for printing and visualization + // + + // The size of the MNK-mode + template + CUTE_HOST_DEVICE constexpr + auto + tile_size_mnk() const { + static_assert(0 <= I && I < 3); + auto core_size = size(AtomShape_MNK{}) * size(get_thr_layout_vmnk()); + [[maybe_unused]] auto perm_size = size(PermutationMNK{}); + if constexpr (is_underscore::value) { + return core_size; + } else { + return cute::max(core_size, perm_size); + } + + CUTE_GCC_UNREACHABLE; + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutC_MN() const + { + // (M,N) -> (M,N) + auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>())); + // (cthrid,val) -> (M,N) + auto layoutC_TV = thrfrg_C(ref_C); + // (M,N) -> (cthrid,frg) + auto layoutC_MN = right_inverse(layoutC_TV).with_shape(shape(ref_C)); + + // cthrid = (v,m,n) -> thr_idx + auto thrID_C = thr_layout_vmnk_(_,_,_,Int<0>{}); + + return cute::make_tuple(layoutC_MN, thrID_C); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutC_TV() const + { + // (M,N) -> (M,N) + auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>())); + // (cthrid,val) -> (M,N) + auto layoutC_TV = thrfrg_C(ref_C); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + + // (thr_idx,val) -> (M,N) + return layoutC_TV.compose(thridx_2_thrid, _); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutA_MK() const + { + // (M,K) -> (M,K) + auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>())); + // (athrid,val) -> (M,K) + auto layoutA_TV = thrfrg_A(ref_A); + // (M,K) -> (athrid,frg) + auto layoutA_MK = right_inverse(layoutA_TV).with_shape(shape(ref_A)); + + // athrid = (v,m,k) -> thr_idx + auto thrID_A = thr_layout_vmnk_(_,_,Int<0>{},_); + + return cute::make_tuple(layoutA_MK, 
thrID_A); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutA_TV() const + { + // (M,K) -> (M,K) + auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>())); + // (athrid,val) -> (M,K) + auto layoutA_TV = thrfrg_A(ref_A); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto atile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + + // (thr_idx,val) -> (M,K) + return thrfrg_A(ref_A).compose(atile, _).compose(thridx_2_thrid, _); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutB_NK() const + { + // (N,K) -> (N,K) + auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>())); + // (bthrid,val) -> (N,K) + auto layoutB_TV = thrfrg_B(ref_B); + // (N,K) -> (bthrid,frg) + auto layoutB_NK = right_inverse(layoutB_TV).with_shape(shape(ref_B)); + + // bthrid = (v,n,k) -> thr_idx + auto thrID_B = thr_layout_vmnk_(_,Int<0>{},_,_); + + return cute::make_tuple(layoutB_NK, thrID_B); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutB_TV() const + { + // (N,K) -> (N,K) + auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>())); + // (bthrid,val) -> (N,K) + auto layoutB_TV = thrfrg_B(ref_B); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)), + make_stride( Int<0>{} , Int<1>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + + // (thr_idx,val) -> (N,K) + return thrfrg_B(ref_B).compose(btile, _).compose(thridx_2_thrid, _); + } +}; + +template +struct ThrMMA : TiledMMA +{ + ThrVMNK thr_vmnk_; + + template + CUTE_HOST_DEVICE constexpr + auto + partition_C(CTensor&& ctensor) const + { + auto thr_tensor = make_tensor(static_cast(ctensor).data(), this->thrfrg_C(ctensor.layout())); + + auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_))); + return thr_tensor(thr_vmn, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_A(ATensor&& atensor) const + { + auto thr_tensor = make_tensor(static_cast(atensor).data(), this->thrfrg_A(atensor.layout())); + + auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_))); + return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_B(BTensor&& btensor) const + { + auto thr_tensor = make_tensor(static_cast(btensor).data(), this->thrfrg_B(btensor.layout())); + + auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_))); + return thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_C(CTensor&& ctensor) const + { + return TiledMMA::make_fragment_C(partition_C(ctensor)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_A(ATensor&& atensor) const + { + return TiledMMA::make_fragment_A(partition_A(atensor)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_B(BTensor&& btensor) const + { + return TiledMMA::make_fragment_B(partition_B(btensor)); + } +}; + +// +// These tile the MMA_Atom as a whole +// + +template >, + class Permutations = Tile> +CUTE_HOST_DEVICE constexpr +auto 
+make_tiled_mma(MMA_Atom const& mma_atom, + MMAThrLayout const& thr_layout = {}, + Permutations const& permutations = {}) +{ + auto thr_layout_mnk = append<3>(thr_layout, Layout<_1,_0>{}); + auto permutation_mnk = append<3>(permutations, _); + + return TiledMMA, + decltype(thr_layout_mnk), + decltype(permutation_mnk)>{mma_atom, thr_layout_mnk}; +} + +template >, + class Permutations = Tile> +CUTE_HOST_DEVICE constexpr +auto +make_tiled_mma(MMA_Op const&, + MMAThrLayout const& thr_layout = {}, + Permutations const& permutations = {}) +{ + // Attempt to wrap in an MMA_Atom<> and forward + return make_tiled_mma(MMA_Atom{}, thr_layout, permutations); +} + +// +// partition_fragment_C -- static context +// + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_C(TiledMMA const& mma, Shape_MN const& shape_MN) +{ + constexpr int R = rank_v; + static_assert(R >= 2, "Must have at least rank-2"); + auto atomMNK = typename TiledMMA::AtomShape_MNK{}; + auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; + auto V = shape<1>(typename TiledMMA::AtomLayoutC_TV{}); + auto M = shape_div(size<0>(shape_MN), size<0>(atomMNK) * size<1>(thrVMNK)); + auto N = shape_div(size<1>(shape_MN), size<1>(atomMNK) * size<2>(thrVMNK)); + return cute::tuple_cat(make_shape(V,M,N), take<2,R>(shape_MN)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +partition_fragment_C(TiledMMA const& mma, Shape_MN const& shapeMN) +{ + return make_tensor::FrgTypeC>(partition_shape_C(mma, shapeMN)); +} + +// partition_fragment_A and partition_fragment_B often depend on the +// layout of A and B and/or the thread_idx that is requesting the partition. +// For these reasons, they should not be used in a static context. +// See TiledMMA::get_slice(thr_idx).partition_fragment_A(tensorA) instead. + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_A(TiledMMA const& mma, Shape_MK const& shape_MK) +{ + constexpr int R = rank_v; + static_assert(R >= 2, "Must have at least rank-2"); + auto atomMNK = typename TiledMMA::AtomShape_MNK{}; + auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; + auto V = shape<1>(typename TiledMMA::AtomLayoutA_TV{}); + auto M = shape_div(size<0>(shape_MK), size<0>(atomMNK) * size<1>(thrVMNK)); + auto K = shape_div(size<1>(shape_MK), size<2>(atomMNK) * size<3>(thrVMNK)); + return cute::tuple_cat(make_shape(V,M,K), take<2,R>(shape_MK)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_B(TiledMMA const& mma, Shape_NK const& shape_NK) +{ + constexpr int R = rank_v; + static_assert(R >= 2, "Must have at least rank-2"); + auto atomMNK = typename TiledMMA::AtomShape_MNK{}; + auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; + auto V = shape<1>(typename TiledMMA::AtomLayoutB_TV{}); + auto N = shape_div(size<0>(shape_NK), size<1>(atomMNK) * size<2>(thrVMNK)); + auto K = shape_div(size<1>(shape_NK), size<2>(atomMNK) * size<3>(thrVMNK)); + return cute::tuple_cat(make_shape(V,N,K), take<2,R>(shape_NK)); +} + +// +// Size +// + +template +CUTE_HOST_DEVICE constexpr +auto +tile_size(TiledMMA const& mma) +{ + return mma.template tile_size_mnk(); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_shape(TiledMMA const& mma) +{ + return make_shape(tile_size<0>(mma), tile_size<1>(mma), tile_size<2>(mma)); +} + +// Deprecate? 
+template +CUTE_HOST_DEVICE constexpr +auto +size(TiledMMA const& mma) +{ + return size(mma.get_thr_layout_vmnk()); +} + +// Alias +template +CUTE_HOST_DEVICE constexpr +auto +thr_size(TiledMMA const& mma) +{ + return size(mma.get_thr_layout_vmnk()); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE +void +print(MMA_Atom> const&) +{ + using Atom = MMA_Atom>; + print("MMA_Atom\n"); + print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n"); + print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n"); + print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n"); +} + +template +CUTE_HOST_DEVICE +void +print(TiledMMA const& mma) +{ + print("TiledMMA\n"); + print(" ThrLayoutVMNK: "); print(mma.get_thr_layout_vmnk()); print("\n"); + print(" PermutationMNK: "); print(TiledPerm{}); print("\n"); + print(static_cast(mma)); +} + +template +CUTE_HOST_DEVICE +void +print(ThrMMA const& thr_mma) +{ + print("ThrMMA\n"); + print(" Thr VMNK: "); print(thr_mma.thr_vmnk_); print("\n"); + print(static_cast(thr_mma)); +} + +template +CUTE_HOST_DEVICE +void +print_latex(MMA_Atom const& mma_atom) +{ + print_latex(make_tiled_mma(mma_atom)); +} + +template +CUTE_HOST_DEVICE +void +print_latex(TiledMMA const& mma) +{ + auto layout_and_thrid_C = mma.get_layoutC_MN(); + auto layoutC_MN = get<0>(layout_and_thrid_C); + auto thrID_C = get<1>(layout_and_thrid_C); + + auto layout_and_thrid_A = mma.get_layoutA_MK(); + auto layoutA_MK = get<0>(layout_and_thrid_A); + auto thrID_A = get<1>(layout_and_thrid_A); + + auto layout_and_thrid_B = mma.get_layoutB_NK(); + auto layoutB_NK = get<0>(layout_and_thrid_B); + auto thrID_B = get<1>(layout_and_thrid_B); + + print_latex_mma(layoutC_MN, thrID_C, + layoutA_MK, thrID_A, + layoutB_NK, thrID_B); +} + +// MNK MMA Layout to console printer +template +CUTE_HOST_DEVICE +void +print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx + LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + assert(size<0>(A) == size<0>(C)); + assert(size<0>(B) == size<1>(C)); + assert(size<1>(A) == size<1>(B)); + + int a_width = size<1>(A) * 6 + 4; + + // Print out B (white-shifted) k-by-n + for (int k = 0; k < size<1>(B); ++k) { + // Header + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("+-----"); + printf("+\n"); + // Values + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB))); + printf("|\n"); + } + // Footer + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("+-----"); + printf("+\n\n"); + + // Print out A m-by-k and C m-by-n + for (int m = 0; m < size<0>(A); ++m) { + // Header + for (int k = 0; k < size<1>(A); ++k) printf("+-----"); + printf("+ "); + for (int n = 0; n < size<1>(C); ++n) printf("+-----"); + printf("+\n"); + // Values + for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA))); + printf("| "); + for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), int(C(m,n) / size(TC))); + printf("|\n"); + } + // Footer + for (int k = 0; k < size<1>(A); ++k) printf("+-----"); + 
printf("+ "); + for (int n = 0; n < size<1>(C); ++n) printf("+-----"); + printf("+\n"); +} + +// MNK MMA Layout to Latex TIKZ -- 8-value color coded by thread +template +CUTE_HOST_DEVICE +void +print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx + LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + assert(size<0>(A) == size<0>(C)); + assert(size<0>(B) == size<1>(C)); + assert(size<1>(A) == size<1>(B)); + + char const* latex_header = + "\\documentclass{standalone}\n" + "\\usepackage{tikz}\n" + "\\usetikzlibrary{external}\n" + "\\tikzexternalize\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; + char const* latex_footer = + "\\end{tikzpicture}\n" + "\\end{document}\n"; + + char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}"}; + + // Header + printf("%% LayoutC: "); print(C); printf("\n"); + printf("%% ThrIDC : "); print(TC); printf("\n"); + printf("%% LayoutA: "); print(A); printf("\n"); + printf("%% ThrIDA : "); print(TA); printf("\n"); + printf("%% LayoutB: "); print(B); printf("\n"); + printf("%% ThrIDB : "); print(TB); printf("\n\n"); + + printf(latex_header); + + // C starting at 0,0 + for (int m = 0; m < size<0>(C); ++m) { + for (int n = 0; n < size<1>(C); ++n) { + int thrid = C(m,n) % size(TC); + int val_idx = C(m,n) / size(TC); + int thr_idx = TC(thrid); + + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[thr_idx % 8], + m, n, + thr_idx, val_idx); + } + } + + // A starting at 0,-size<1>(A)-1 + for (int m = 0; m < size<0>(A); ++m) { + for (int k = 0; k < size<1>(A); ++k) { + int thrid = A(m,k) % size(TA); + int val_idx = A(m,k) / size(TA); + int thr_idx = TA(thrid); + + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[thr_idx % 8], + m, k-1-size<1>(A), + thr_idx, val_idx); + } + } + + // B starting at -size<1>(B)-1,0 + for (int n = 0; n < size<0>(B); ++n) { + for (int k = 0; k < size<1>(B); ++k) { + int thrid = B(n,k) % size(TB); + int val_idx = B(n,k) / size(TB); + int thr_idx = TB(thrid); + + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[thr_idx % 8], + k-1-size<1>(B), n, + thr_idx, val_idx); + } + } + + // A labels + for (int m = 0, k = -1; m < size<0>(A); ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m); + } + for (int k = 0, m = -1; k < size<1>(A); ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k); + } + // B labels + for (int n = 0, k = -1; n < size<0>(B); ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, n); + } + for (int k = 0, n = -1; k < size<1>(B); ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, k); + } + + // Footer + printf(latex_footer); +} + +} // namespace cute + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cute/atom/mma_traits.hpp b/server/punica_kernels/include/cutlass/cute/atom/mma_traits.hpp new file mode 100644 index 00000000..8c090936 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/mma_traits.hpp @@ -0,0 +1,228 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +namespace detail { + +template +struct supports_output_scaling { static constexpr bool value = false; }; + +template +struct supports_output_scaling().accumulate_)>> { static constexpr bool value = true; }; + +} // end namespace detail + +/** + * concept MMA_Traits + * { + * using ValTypeD = // Logical A-value type + * using ValTypeA = // Logical B-value type + * using ValTypeB = // Logical C-value type + * using ValTypeC = // Logical D-value type (NOTE: Not used? 
Assumed == ValTypeD) + * + * using FrgTypeA = // A-type consumed by MMA (if ommitted, same as ValTypeA) + * using FrgTypeB = // B_type consumed by MMA (if ommitted, same as ValTypeB) + * using FrgTypeC = // C_type consumed by MMA (if ommitted, same as ValTypeC) + * + * using Shape_MNK = // Logical MxNxK shape of the MMA + * + * using ThrID = // Logical thread id (tid) -> tidx + * + * using ALayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MK-coord + * using BLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat NK-coord + * using CLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MN-coord + * }; + */ + +template +struct MMA_Traits +{ + static_assert(sizeof(MMAOperation) == 0, "MMA_Traits not implemented for this MMA_Operation."); +}; + +template +struct MMA_Traits> +{ + using ValTypeD = D; + using ValTypeA = A; + using ValTypeB = B; + using ValTypeC = C; + + // Logical shape of the MMA + using Shape_MNK = Shape<_1,_1,_1>; + + // Logical thread id (tid) -> tidx + using ThrID = Layout<_1>; + + // (Logical thread id (tid), Logical value id (vid)) -> coord + + // (tid,vid) -> (m,k) + using ALayout = Layout>; + // (tid,vid) -> (n,k) + using BLayout = Layout>; + // (tid,vid) -> (m,n) + using CLayout = Layout>; +}; + +// +// Generic mma_unpack for any MMA_Traits +// +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + using MMATraits = MMA_Traits; + + [[maybe_unused]] constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + Tensor rA = recast(A); + Tensor rB = recast(B); + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + + if constexpr (is_same::value) + { + static_assert(is_same::value, "GMMA C and D value_type must match."); + static_assert(is_same::value, "GMMA C and D layouts must match."); + // assert((void*)&C == (void*)&D); + + Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D + + //CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + if constexpr (detail::supports_output_scaling::value) { + detail::explode_with_d_scaling(MMA_Op::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + traits.accumulate_); + } + else { + detail::explode(MMA_Op::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); + } + } + else { + Tensor rD = recast(D); + Tensor rC = recast(C); + + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + if constexpr (detail::supports_output_scaling::value) { + detail::explode_with_d_scaling(MMA_Op::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + traits.accumulate_); + } + else { + detail::explode(MMA_Op::fma, + rD, make_int_sequence{}, + rA, 
make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); + } + } +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + mma_unpack(traits, D, A, B, C); +} + +namespace detail { + +template +struct FrgTypeA_or_Default { using type = typename X::ValTypeA; }; +template +struct FrgTypeA_or_Default> { using type = typename X::FrgTypeA; }; + +template +struct FrgTypeB_or_Default { using type = typename X::ValTypeB; }; +template +struct FrgTypeB_or_Default> { using type = typename X::FrgTypeB; }; + +template +struct FrgTypeC_or_Default { using type = typename X::ValTypeC; }; +template +struct FrgTypeC_or_Default> { using type = typename X::FrgTypeC; }; + +} // end namespace detail + +} // namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm61.hpp b/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm61.hpp new file mode 100644 index 00000000..f72a6394 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm61.hpp @@ -0,0 +1,73 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_1,_1,_4>; + using ThrID = Layout<_1>; + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int16_t; + using ValTypeB = int16_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_1,_1,_2>; + using ThrID = Layout<_1>; + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +} // namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm70.hpp b/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm70.hpp new file mode 100644 index 00000000..d1c8b61f --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm70.hpp @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +namespace { + +// Logical thread id to thread idx (quadpair) +using SM70_QuadPair = Layout, + Stride<_1,_16>>; +// (T8,V4) -> (M8,K4) +using SM70_8x4_Row = Layout, + Stride<_1,_8>>; +// (T8,V4) -> (M8,K4) +using SM70_8x4_Col = Layout,_4>, + Stride,_1>>; +// (T8,V8) -> (M8,N8) +using SM70_8x8_16b = Layout, + Stride<_1,_8>>; +// (T8,V8) -> (M8,N8) +using SM70_8x8_32b = Layout,Shape <_2,_2, _2>>, + Stride,Stride<_8,_2,_32>>>; + +} + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = 
half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// +} // namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm75.hpp b/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm75.hpp new file mode 100644 index 00000000..1d3f5196 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm75.hpp @@ -0,0 +1,81 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + using BLayout = Layout,_2>, + Stride,_8>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,_4>, + Stride,_8>>; + using BLayout = Layout,_4>, + Stride,_8>>; + using CLayout = Layout,_2>, + Stride,_8>>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm80.hpp b/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm80.hpp new file mode 100644 index 00000000..ab402881 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm80.hpp @@ -0,0 +1,442 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include + +namespace cute +{ + +namespace { + +// (T32,V1) -> (M8,N8) +using SM80_8x4 = Layout,_1>, + Stride,_0>>; +// (T32,V2) -> (M8,N8) +using SM80_8x8_Row = Layout,_2>, + Stride,_8>>; +// (T32,V4) -> (M8,N16) +using SM80_8x16_Row = Layout,_4>, + Stride,_8>>; +// (T32,V4) -> (M16,N8) +using SM80_16x8_Row = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + +} + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp16 = fp16 * fp16 + fp16 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = SM80_16x8_Row; + using BLayout = SM80_8x8_Row; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _2,_2, _2>>, + Stride,Stride<_16,_8,_128>>>; + using BLayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = fp16 * fp16 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = bf16 * bf16 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = tf32 * tf32 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = cutlass::tfloat32_t; + using ValTypeB = cutlass::tfloat32_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = Layout,_2>, + Stride,_8>>; + using BLayout = SM80_8x4; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = cutlass::tfloat32_t; + using ValTypeB = cutlass::tfloat32_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using BLayout = 
Layout, _2>, + Stride,_32>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = SM80_8x4; + using BLayout = SM80_8x4; + using CLayout = SM80_8x8_Row; +}; + +// Custom complex fp64 MMA composed of 4 fp64 MMAs -- same layouts +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +// Custom complex fp64 MMA composed of 3 fp64 MMAs -- same layouts +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s8 * s8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = SM80_8x16_Row; + using BLayout = SM80_8x16_Row; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2>>, + Stride,Stride<_16,_8>>>; + using BLayout = SM80_8x16_Row; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_32>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_16,_8,_256>>>; + using BLayout = Layout, Shape <_4, _2>>, + Stride, Stride<_8,_128>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s8 * u8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : 
MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u8 * s8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u8 * u8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = b1 ^ b1 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = cute::uint1b_t; + using ValTypeB = cute::uint1b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_256>; + using ThrID = Layout<_32>; + using ALayout = Layout>, + Stride<_64,Stride<_64,_16,_8,_2048>>>; + using BLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + using CLayout = SM80_16x8_Row; +}; +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm90.hpp b/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm90.hpp new file mode 100644 index 00000000..b4406389 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm90.hpp @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. 
+ * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute { + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_16,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = Layout,_2>, + Stride,_8>>; + using BLayout = Layout,_1>, + Stride,_0>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _2>, + Stride,_32>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _4>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _4>, + Stride,_32>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +/////////////////////////////////////////////////////////////////////////////////// +//////////////////////// cfp64 = cfp64 * cfp64 + cfp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm90_gmma.hpp 
b/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm90_gmma.hpp new file mode 100644 index 00000000..3bbcd1fb --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/atom/mma_traits_sm90_gmma.hpp @@ -0,0 +1,5929 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute { + +// Fence between the async destination accumulators of GMMA & source for their dependent use +template +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(Tensor& frg) { + CUTE_STATIC_ASSERT(is_static::value); + if constexpr (is_same_v) { + auto f32_frg = recast(frg); + CUTE_UNROLL + for (int i = 0; i < size(f32_frg); ++i) { + warpgroup_fence_operand(f32_frg(i)); + } + } + else { + CUTE_STATIC_ASSERT(is_rmem::value); + auto u32_frg = recast(frg); + CUTE_UNROLL + for (int i = 0; i < size(u32_frg); ++i) { + warpgroup_fence_operand(u32_frg(i)); + } + } +} + +namespace GMMA { + +/////////////////////////////////////////// +// Common layouts for GMMA Shared Memory // +/////////////////////////////////////////// + +// M|N-major GMMA layouts in units of bits +using Layout_MN_INTER_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _128>>>; +using Layout_MN_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _256>>>; +using Layout_MN_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _512>>>; +using Layout_MN_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1,_1024>>>; + +// K-major GMMA layouts in units of bits +using Layout_K_INTER_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _128,_1>>>; +using Layout_K_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _256,_1>>>; +using Layout_K_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _512,_1>>>; +using Layout_K_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1024,_1>>>; + +// M|N-major layouts in units of Type +template +using Layout_MN_INTER_Atom = decltype(upcast::value>(Layout_MN_INTER_Atom_Bits{})); +template +using Layout_MN_SW32_Atom = decltype(upcast::value>(Layout_MN_SW32_Atom_Bits{})); +template +using Layout_MN_SW64_Atom = decltype(upcast::value>(Layout_MN_SW64_Atom_Bits{})); +template +using Layout_MN_SW128_Atom = decltype(upcast::value>(Layout_MN_SW128_Atom_Bits{})); + +// K-major layouts in units of Type +template +using Layout_K_INTER_Atom = decltype(upcast::value>(Layout_K_INTER_Atom_Bits{})); +template +using Layout_K_SW32_Atom = decltype(upcast::value>(Layout_K_SW32_Atom_Bits{})); +template +using Layout_K_SW64_Atom = decltype(upcast::value>(Layout_K_SW64_Atom_Bits{})); +template +using Layout_K_SW128_Atom = decltype(upcast::value>(Layout_K_SW128_Atom_Bits{})); + +// With GMMA::Major param +template +using Layout_INTER_Atom = typename conditional, + Layout_K_INTER_Atom>::type; +template +using Layout_SW32_Atom = typename conditional, + Layout_K_SW32_Atom>::type; +template +using Layout_SW64_Atom = typename conditional, + Layout_K_SW64_Atom>::type; +template +using Layout_SW128_Atom = typename conditional, + Layout_K_SW128_Atom>::type; + +// +// Tensor (position-dependent swizzle) to LayoutType utility +// + +template +CUTE_HOST_DEVICE constexpr +LayoutType +layout_type(Tensor> const&) +{ + static_assert(is_same::value, + "Expected uint128_t type in LayoutType conversion."); + + using Swizzle = get_swizzle_t; + constexpr int B = Swizzle::num_bits; + constexpr int M = Swizzle::num_base; + constexpr int S = Swizzle::num_shft; + + static_assert(M == 4, "Unsupported layout swizzle"); + static_assert(0 <= B && B <= 3, "Unsupported layout swizzle"); + static_assert(S == 3, "Unsupported layout swizzle"); + + switch (B) { + case 0: return LayoutType::INTERLEAVE; + 
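+    // The remaining cases map the swizzle's bit count B to the canonical GMMA layout type:
+    // B=1 is a 32B swizzle, B=2 a 64B swizzle, B=3 a 128B swizzle (B=0, above, is unswizzled).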
case 1: return LayoutType::B32; + case 2: return LayoutType::B64; + case 3: return LayoutType::B128; + } + return LayoutType::INTERLEAVE; // ERROR +} + +/////////////////////////////////////////////////////////////////////////////// +// Construction method for GMMA Descriptors +/////////////////////////////////////////////////////////////////////////////// + +/** +* /////////////////////////////// +* // make_gmma_desc // +* /////////////////////////////// +* Each GmmaDescriptor Major-MN describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO)) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO)) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO)) +* +* where +* T : sizeof(uint128_t) / sizeof(value_type) +* m : integer in [1,16] corresponding to GMMA shape +* k : integer in [1,32] corresponding to GMMA shape +* SBO: stride byte offset +* LBO: leading byte offset +* +* See GMMA::Layout_MN_XXX_Atom for building canonical GmmaDescriptor Major-MN layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_MN_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. +* +* ////////////////////////////// +* // make_gmma_desc // +* ////////////////////////////// +* Each GmmaDescriptor Major-K describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T )) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T )) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T )) +* +* See GMMA::Layout_K_XXX_Atom for building canonical GmmaDescriptor Major-K layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. +*/ +template +CUTE_HOST_DEVICE constexpr +GmmaDescriptor +make_gmma_desc(Tensor const& tensor) +{ + static_assert(is_smem::value, "GMMA Descriptors can only be constructed on smem."); + static_assert(TLayout::rank == 2, "GMMA Descriptors can only be constructed on rank-2 tensors."); + using value_type = typename TEngine::value_type; + + Tensor u128_tensor = recast(tensor); + + // Result + GmmaDescriptor desc; + + // Layout type + constexpr GMMA::LayoutType LAYOUT_TYPE = GMMA::layout_type(u128_tensor); + desc.bitfield.layout_type_ = uint8_t(LAYOUT_TYPE); + + // Start address (4LSB not included) + uint32_t start_address = cast_smem_ptr_to_uint(raw_pointer_cast(u128_tensor.data())); + desc.bitfield.start_address_ = start_address >> 4; + + constexpr uint8_t base_offset = 0; + desc.bitfield.base_offset_ = base_offset; + + // LayoutType meta + constexpr int W = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? 1 : + LAYOUT_TYPE == GMMA::LayoutType::B32 ? 2 : + LAYOUT_TYPE == GMMA::LayoutType::B64 ? 4 : + LAYOUT_TYPE == GMMA::LayoutType::B128 ? 
8 : -1; + + if constexpr (MajorMode == GMMA::Major::MN) + { + /* In units of uint128_t, each GmmaDescriptor Major-MN describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((1,n),(8,k)):((X,SBO),(1,LBO)) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((2,n),(8,k)):((1,LBO),(2,SBO)) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((4,n),(8,k)):((1,LBO),(4,SBO)) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),(8,k)):((1,LBO),(8,SBO)) + */ + static_assert(size<1>(u128_tensor) == Int<(256 / cute::sizeof_bits::value)>{}, // K size + "Not a canonical GMMA_MN Layout: Expected K-size 256/sizeof_bits."); + + // Construct the canonical GMMA T Layout with shape ((W,n),(8,2)) + Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout,_1>{}, Layout,_1>{})); + + // Check ranks of canonical + CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); + CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? stride<0,0>(canonical_layout) : 1; + static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_MN Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = W; + static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_MN Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); + + desc.bitfield.stride_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_01 : stride_11; + desc.bitfield.leading_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? 
stride_11 : stride_01; + } + else if constexpr (MajorMode == GMMA::Major::K) + { + /* In units of uint128_t, each GmmaDescriptor Major-K describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,n),2):((1,SBO),LBO) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,n),2):((2,SBO),1) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,n),2):((4,SBO),1) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),2):((8,SBO),1) + */ + CUTE_STATIC_ASSERT_V(size<0>(u128_tensor) % Int<8>{} == Int<0>{}, // N|M size + "Not a canonical GMMA_K Layout: Expected MN-size multiple of 8."); + CUTE_STATIC_ASSERT_V(size<1>(u128_tensor) == Int<2>{}, // K size + "Not a canonical GMMA_K Layout: Expected K-size 2 (in units of uint128_t)."); + + // Construct the canonical GMMA N Layout with shape ((8,n),(2,1)) + Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout<_8,_1>{}, Layout<_2,_1>{})); + + // Check ranks of canonical + CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_K Layout: No flat offset mode"); + CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_K Layout: No flat offset mode"); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = W; + static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_K Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride<1,0>(canonical_layout) : 1; + static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_K Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + + desc.bitfield.stride_byte_offset_ = stride_01; + desc.bitfield.leading_byte_offset_ = stride_10; + } else { + static_assert(MajorMode != GMMA::Major::MN && MajorMode != GMMA::Major::K, "Unrecognized MajorMode!"); + } + +#if 0 + // DEBUG and SANITY + assert((start_address & 0b0000001111) == 0); // Must be 16B aligned (4LSB are 0) no negotiation + assert((start_address & 0b1110000000) == 0); // Assert base_offset is 0, generalize later + if (thread0()) { + print("smem_desc input tensor: "); print(tensor.data()); print(" o "); print(tensor.layout()); print("\n"); + print("smem_desc uint128_t tensor: "); print(u128_tensor.data()); print(" o "); print(u128_tensor.layout()); print("\n"); + //print(" desc canonical layout: "); print(canonical_layout); print("\n"); + print(desc); + } +#endif + + return desc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Higher level GMMA Descriptor utilities +/////////////////////////////////////////////////////////////////////////////// + +struct DescriptorIterator +{ + using reference = GmmaDescriptor; + using element_type = GmmaDescriptor; + using value_type = GmmaDescriptor; + + GmmaDescriptor desc_; + + // Dereference returns the GmmaDescriptor + CUTE_HOST_DEVICE constexpr + reference operator*() const { return desc_; } + + // Advance and return a new GmmaDescriptor + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { return *(*this + i); } + + // Return an advanced iterator + template + CUTE_HOST_DEVICE constexpr + DescriptorIterator operator+(Index const& offset) 
const + { + return { GmmaDescriptor{desc_ + uint64_t(offset)} }; + } + + CUTE_HOST_DEVICE friend void + print(DescriptorIterator) { printf("GMMA::DescriptorIterator"); } +}; + +template +CUTE_HOST_DEVICE constexpr +GmmaDescriptor +raw_pointer_cast(DescriptorIterator const& ptr) { + return ptr.desc_; +} + +// Recast a DescriptorIterator Tensor to uint64_t, it's RegType in mma_unpack +template +CUTE_HOST_DEVICE constexpr +DescriptorIterator +recast_ptr(DescriptorIterator const& iter) { + static_assert(is_same::value, "Can only cast GmmaDescriptorIterator to uint64_t."); + return iter; // Do nothing, it will still dereference to GmmaDescriptor and decay to uint64_t +} + +// The GMMA Traits below have custom fragment type flags for their smem desc tensors. +// These flags specialize a MakeTensor customization point to correctly make the fragment that is desired. +template +struct smem_desc : DescriptorIterator {}; + +} // end namespace GMMA + +// Customization point for creating a GMMA::smem_desc Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); + return make_tensor(GMMA::DescriptorIterator{GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// MMA_TRAITS /////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace GMMA { + +// Accumulator layouts +using CLayout_64x8 = Layout,Shape < _2,_2>>, + Stride,Stride<_64,_8>>>; + +using CLayout_64x16 = Layout,Shape < _2,_2, _2>>, + Stride,Stride<_64,_8,_512>>>; + +using CLayout_64x32 = Layout,Shape < _2,_2, _4>>, + Stride,Stride<_64,_8,_512>>>; + +using CLayout_64x64 = Layout,Shape < _2,_2, _8>>, + Stride,Stride<_64,_8,_512>>>; + +using CLayout_64x96 = Layout,Shape < _2,_2, _12>>, + Stride,Stride<_64,_8,_512>>>; + +using CLayout_64x128 = Layout,Shape < _2,_2, _16>>, + Stride,Stride<_64,_8,_512>>>; + +using CLayout_64x192 = Layout,Shape < _2,_2, _24>>, + Stride,Stride<_64,_8,_512>>>; + +using CLayout_64x256 = Layout,Shape < _2,_2, _32>>, + Stride,Stride<_64,_8,_512>>>; + +// Register source layout for 32-bit value types +using ALayout_64x8 = Layout,Shape < _2, _2>>, + Stride,Stride< _8,_256>>>; + +// Register source layout for 16-bit value types +using ALayout_64x16 = CLayout_64x16; + +// Register source layout for 8-bit value types +using ALayout_64x32 = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_64,_8,_1024>>>; + +// Shared memory source layouts for any value type +template +using ABLayout = Layout,Int>>, + Stride< _0,Stride< _1,Int>>>; + +} // namespace GMMA + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + 
using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + 
using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
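+  // accumulate_ is the scale-D flag forwarded to the GMMA op when the atom issues the
+  // instruction: ScaleOut::One accumulates into C (D = A*B + C), ScaleOut::Zero discards
+  // the prior accumulator contents (D = A*B).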
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using 
ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
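+
+// Usage sketch (illustrative; the names below are assumptions, not part of this header): the traits
+// in this file are consumed through cute::make_tiled_mma. The op tag follows the usual CUTLASS naming
+// for a 64x64x16 f32 <= bf16 * bf16 smem-by-smem warpgroup MMA; sA/sB/gC stand for caller-provided
+// shared-memory A/B tiles and the global-memory C tile. Warpgroup arrive/commit/wait fencing is
+// omitted for brevity.
+//
+//   auto tiled_mma = make_tiled_mma(SM90_64x64x16_F32BF16BF16_SS<GMMA::Major::K,GMMA::Major::K>{});
+//   auto thr_mma   = tiled_mma.get_thread_slice(threadIdx.x);
+//   auto tCsA = thr_mma.partition_A(sA);           // (MMA,MMA_M,MMA_K) view of smem A
+//   auto tCsB = thr_mma.partition_B(sB);           // (MMA,MMA_N,MMA_K) view of smem B
+//   auto tCrA = thr_mma.make_fragment_A(tCsA);     // GMMA::smem_desc tensor (FrgTypeA above)
+//   auto tCrB = thr_mma.make_fragment_B(tCsB);     // GMMA::smem_desc tensor (FrgTypeB above)
+//   auto tCrC = thr_mma.partition_fragment_C(gC);  // register accumulators laid out per CLayout
+//   gemm(tiled_mma, tCrA, tCrB, tCrC);             // C += A * B across the warpgroup
+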
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 8, 8>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 8, 8>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 16, 8>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 16, 8>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + 
using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 32, 8>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 32, 8>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 64, 8>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 64, 8>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 96, 8>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 96, 8>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_8>; + using ThrID = Layout<_128>; 
+ using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<128, 8>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<128, 8>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<192, 8>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<192, 8>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<256, 8>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<256, 8>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <>
+struct MMA_Traits<SM90_64x16x32_S32S8S8_SS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_16,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 16, 32>;
+  using CLayout = GMMA::CLayout_64x16;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <>
+struct MMA_Traits<SM90_64x32x32_S32S8S8_SS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_32,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 32, 32>;
+  using CLayout = GMMA::CLayout_64x32;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <>
+struct MMA_Traits<SM90_64x64x32_S32S8S8_SS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_64,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 64, 32>;
+  using CLayout = GMMA::CLayout_64x64;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <>
+struct MMA_Traits<SM90_64x96x32_S32S8S8_SS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_96,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 96, 32>;
+  using CLayout = GMMA::CLayout_64x96;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <>
+struct MMA_Traits<SM90_64x128x32_S32S8S8_SS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_128,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<128, 32>;
+  using CLayout = GMMA::CLayout_64x128;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <>
+struct MMA_Traits<SM90_64x192x32_S32S8S8_SS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_192,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<192, 32>;
+  using CLayout = GMMA::CLayout_64x192;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
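The traits in this stretch alternate between two forms of the same tile shape: when both FrgTypeA and FrgTypeB are GMMA::smem_desc, the A and B operands are read through shared-memory matrix descriptors (the *_SS instruction forms); when only FrgTypeB is a descriptor and ALayout is a GMMA::ALayout_64x32 register layout, A is sourced from registers (the *_RS forms). As a reference point, here is a fully spelled-out sketch of the register-sourced specialization immediately above, assuming the upstream CuTe operation name SM90_64x96x32_S32U8S8_RS_TN and a K-major descriptor argument for B; both of those names follow upstream CUTLASS 3.x conventions and should be checked against the vendored cute/arch/mma_sm90_gmma.hpp rather than taken from this diff.

// Sketch (assumed names noted above); member values are those listed in the
// trait just before this note.
template <>
struct MMA_Traits<SM90_64x96x32_S32U8S8_RS_TN>
{
  using ValTypeD = int32_t;
  using ValTypeA = uint8_t;                          // A is register-sourced: no FrgTypeA descriptor
  using ValTypeB = int8_t;
  using ValTypeC = int32_t;

  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;  // only B comes in via a smem descriptor

  using Shape_MNK = Shape<_64,_96,_32>;
  using ThrID   = Layout<_128>;
  using ALayout = GMMA::ALayout_64x32;               // register-fragment layout for A
  using BLayout = GMMA::ABLayout< 96, 32>;
  using CLayout = GMMA::CLayout_64x96;

  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};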
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using 
ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
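For context on how these specializations are consumed downstream, here is a minimal sketch of building a warpgroup-wide TiledMMA from one of the atoms whose traits are defined above. Everything in it (the header paths, the SM90_64x64x16_F32BF16BF16_SS name, make_tiled_mma, get_thread_slice, partition_fragment_C) follows upstream CuTe conventions and is not stated in this diff, so treat it as an assumption to verify against the CUTLASS version actually vendored here.

#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

using namespace cute;

// One warpgroup-wide BF16 GMMA atom: A and B both K-major, both sourced from
// shared memory, i.e. one of the *_SS traits in this header.
using Atom     = SM90_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>;
using TiledMma = decltype(make_tiled_mma(Atom{}));

// Device-side use (inside a kernel): the Shape_MNK / ThrID / A,B,CLayout
// members of the matching MMA_Traits drive how CuTe partitions tiles per
// thread of the 128-thread warpgroup:
//
//   TiledMma mma;
//   auto thr_mma = mma.get_thread_slice(threadIdx.x);
//   auto acc     = thr_mma.partition_fragment_C(gC);   // register accumulators
//   // Shared-memory tiles of A and B are partitioned through the same atom
//   // and handed to cute::gemm(mma, ..., acc), which lowers to wgmma.mma_async.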
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; 
+ using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + 
using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + 
using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
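Kernels rarely touch these trait structs directly; they pick a GMMA op tag and let CuTe's MMA_Atom/TiledMMA machinery look the traits up. A minimal host-compilable sketch of that selection is shown below, assuming the upstream tag name SM90_64x128x32_F32E4M3E5M2_SS_TN, the cute/atom headers, and cute::make_tiled_mma.

#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>
#include <type_traits>

using namespace cute;

// D/C = f32 accumulators, A = e4m3, B = e5m2, 64x128x32 tile,
// both operands sourced from shared memory (SS), TN layout.
using AtomOp   = SM90_64x128x32_F32E4M3E5M2_SS_TN<GMMA::ScaleIn::One, GMMA::ScaleIn::One>;
using TiledMma = decltype(make_tiled_mma(AtomOp{}));

// The matching MMA_Traits specialization supplies the value types, the
// 128-thread warpgroup layout, and the A/B/C layouts for this atom.
using Traits = MMA_Traits<AtomOp>;
static_assert(std::is_same_v<typename Traits::ValTypeA, float_e4m3_t>, "A is e4m3");
static_assert(std::is_same_v<typename Traits::ValTypeB, float_e5m2_t>, "B is e5m2");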
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using 
ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 
64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
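Each shape appears in two flavors: the SS specializations describe both A and B with GMMA shared-memory descriptors, while the RS specializations omit FrgTypeA and switch ALayout to GMMA::ALayout_64x32 because A is fed from registers. A compile-time sketch of that distinction follows; the op-tag names and the GMMA::Major::K descriptor argument are assumed from upstream CUTLASS.

#include <cute/atom/mma_atom.hpp>
#include <type_traits>

using namespace cute;

using SS_Op = SM90_64x64x32_F32E5M2E4M3_SS_TN<GMMA::ScaleIn::One, GMMA::ScaleIn::One>;
using RS_Op = SM90_64x64x32_F32E5M2E4M3_RS_TN<GMMA::ScaleIn::One, GMMA::ScaleIn::One>;

// SS: A is consumed through a K-major shared-memory descriptor.
static_assert(std::is_same_v<typename MMA_Traits<SS_Op>::FrgTypeA,
                             GMMA::smem_desc<GMMA::Major::K>>,
              "SS atoms read A via a GMMA smem descriptor");

// RS: no FrgTypeA is defined; A fragments live in registers, and ALayout is
// the register-operand layout GMMA::ALayout_64x32.
static_assert(std::is_same_v<typename MMA_Traits<RS_Op>::ALayout,
                             GMMA::ALayout_64x32>,
              "RS atoms read A from registers");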
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using 
ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using 
BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
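All of these atoms share ThrID = Layout<_128>, i.e. one Hopper warpgroup, and differ only in the N extent of the accumulator layout (8 through 256). One way to inspect how an atom maps those 128 threads onto its 64xN tile is CuTe's printing utilities; a small host-side sketch, assuming the upstream tag name and the cute::print / cute::print_latex helpers:

#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

int main() {
  using namespace cute;
  // 64x256x32 FP8 atom, f16 accumulation, both operands in shared memory.
  auto mma = make_tiled_mma(
      SM90_64x256x32_F16E5M2E5M2_SS_TN<GMMA::ScaleIn::One, GMMA::ScaleIn::One>{});
  print(mma);           // dumps the thread layout, Shape_MNK, and A/B/C layouts
  // print_latex(mma);  // alternatively, emit a LaTeX picture of the tiling
  return 0;
}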
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using 
ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + 
using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 
32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/config.hpp b/server/punica_kernels/include/cutlass/cute/config.hpp new file mode 100644 index 00000000..941f60d7 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/config.hpp @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(__CUDACC__) || defined(_NVHPC_CUDA) +# define CUTE_HOST_DEVICE __forceinline__ __host__ __device__ +# define CUTE_DEVICE __forceinline__ __device__ +# define CUTE_HOST __forceinline__ __host__ +#else +# define CUTE_HOST_DEVICE inline +# define CUTE_DEVICE inline +# define CUTE_HOST inline +#endif // CUTE_HOST_DEVICE, CUTE_DEVICE + +#if defined(__CUDACC_RTC__) +# define CUTE_HOST_RTC CUTE_HOST_DEVICE +#else +# define CUTE_HOST_RTC CUTE_HOST +#endif + +#if !defined(__CUDACC_RTC__) && !defined(__clang__) && \ + (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)) +# define CUTE_UNROLL #pragma unroll +# define CUTE_NO_UNROLL #pragma unroll 1 +#elif defined(__CUDACC_RTC__) || defined(__clang__) +# define CUTE_UNROLL _Pragma("unroll") +# define CUTE_NO_UNROLL _Pragma("unroll 1") +#else +# define CUTE_UNROLL +# define CUTE_NO_UNROLL +#endif // CUTE_UNROLL + +#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) +# define CUTE_INLINE_CONSTANT static const __device__ +#else +# define CUTE_INLINE_CONSTANT static constexpr +#endif + +// __grid_constant__ was introduced in CUDA 11.7. +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) +# define CUTE_GRID_CONSTANT_SUPPORTED +#endif + +// __grid_constant__ can be enabled only on SM70+. +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) +# define CUTE_GRID_CONSTANT_ENABLED +#endif + +#if ! defined(CUTE_GRID_CONSTANT) +# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) +# define CUTE_GRID_CONSTANT __grid_constant__ +# else +# define CUTE_GRID_CONSTANT +# endif +#endif + +// Some versions of GCC < 11 have trouble deducing that a +// function with "auto" return type and all of its returns in an "if +// constexpr ... else" statement must actually return. Thus, GCC +// emits spurious "missing return statement" build warnings. +// Developers can suppress these warnings by using the +// CUTE_GCC_UNREACHABLE macro, which must be followed by a semicolon. +// It's harmless to use the macro for other GCC versions or other +// compilers, but it has no effect. +#if ! defined(CUTE_GCC_UNREACHABLE) +# if defined(__clang__) || defined(__GNUC__) +# define CUTE_GCC_UNREACHABLE __builtin_unreachable() +# else +# define CUTE_GCC_UNREACHABLE +# endif +#endif + +#if defined(_MSC_VER) +// Provides support for alternative operators 'and', 'or', and 'not' +# include +#endif // _MSC_VER + +#if defined(__CUDACC_RTC__) +# define CUTE_STL_NAMESPACE cuda::std +# define CUTE_STL_NAMESPACE_IS_CUDA_STD +#else +# define CUTE_STL_NAMESPACE std +#endif + +// +// Assertion helpers +// + +#if defined(__CUDACC_RTC__) +# include +#else +# include +#endif + +#define CUTE_STATIC_V(x) decltype(x)::value + +#define CUTE_STATIC_ASSERT static_assert +#define CUTE_STATIC_ASSERT_V(x,...) 
static_assert(decltype(x)::value, ##__VA_ARGS__) + +// Fail and print a message. Typically used for notification of a compiler misconfiguration. +#if defined(__CUDA_ARCH__) +# define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x); __brkpt() +#else +# define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x) +#endif + +// +// IO +// + +#if !defined(__CUDACC_RTC__) +# include +# include +# include +#endif + +// +// Support +// + +#include + +// +// Basic types +// + +#include + +// +// Debugging utilities +// + +#include +#include diff --git a/server/punica_kernels/include/cutlass/cute/container/alignment.hpp b/server/punica_kernels/include/cutlass/cute/container/alignment.hpp new file mode 100644 index 00000000..509579ee --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/container/alignment.hpp @@ -0,0 +1,70 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// Test if a pointer is aligned to N bytes +template +CUTE_HOST_DEVICE constexpr +bool +is_byte_aligned(void const* const ptr) +{ + static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2 in alignment check"); + return (reinterpret_cast(ptr) & (N-1)) == 0; +} + +#if defined(__CUDACC__) +# define CUTE_ALIGNAS(n) __align__(n) +#else +# define CUTE_ALIGNAS(n) alignas(n) +#endif + +template +struct aligned_struct {}; + +template <> struct CUTE_ALIGNAS( 1) aligned_struct< 1> {}; +template <> struct CUTE_ALIGNAS( 2) aligned_struct< 2> {}; +template <> struct CUTE_ALIGNAS( 4) aligned_struct< 4> {}; +template <> struct CUTE_ALIGNAS( 8) aligned_struct< 8> {}; +template <> struct CUTE_ALIGNAS( 16) aligned_struct< 16> {}; +template <> struct CUTE_ALIGNAS( 32) aligned_struct< 32> {}; +template <> struct CUTE_ALIGNAS( 64) aligned_struct< 64> {}; +template <> struct CUTE_ALIGNAS(128) aligned_struct<128> {}; +template <> struct CUTE_ALIGNAS(256) aligned_struct<256> {}; + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/container/array.hpp b/server/punica_kernels/include/cutlass/cute/container/array.hpp new file mode 100644 index 00000000..b40c523d --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/container/array.hpp @@ -0,0 +1,492 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template +struct array +{ + using element_type = T; + using value_type = remove_cv_t; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = element_type&; + using const_reference = const element_type&; + using pointer = element_type*; + using const_pointer = const element_type*; + using iterator = pointer; + using const_iterator = const_pointer; + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference front() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + reference back() + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE constexpr + T* data() + { + return __elems_; + } + + CUTE_HOST_DEVICE constexpr + T const* data() const + { + return __elems_; + } + + CUTE_HOST_DEVICE constexpr + iterator begin() + { + return data(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const + { + return data(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() + { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const + { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + iterator end() + { + return data() + size(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const + { + return data() + size(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + bool empty() const + { + return size() == 0; + } + + CUTE_HOST_DEVICE constexpr + size_type size() const + { + return N; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const + { + return size(); + } + + CUTE_HOST_DEVICE constexpr + void fill(const T& value) + { + for (auto& e : *this) { + e = value; + } + } + + CUTE_HOST_DEVICE constexpr + void clear() + { + fill(T(0)); + } + + CUTE_HOST_DEVICE constexpr + void swap(array& other) + { + using CUTE_STL_NAMESPACE::swap; + for (size_type i = 0; i < size(); ++i) { + swap((*this)[i], other[i]); + } + } + + element_type __elems_[N]; +}; + + +template +struct array +{ + using element_type = T; + using value_type = remove_cv_t; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = element_type&; + using const_reference = const element_type&; + using pointer = element_type*; + using const_pointer = const element_type*; + using const_iterator = const_pointer; + using iterator = pointer; + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference front() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + reference back() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const + { + return *begin(); + } + + 
CUTE_HOST_DEVICE constexpr + T* data() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + T const* data() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + iterator begin() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + iterator end() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + bool empty() const + { + return true; + } + + CUTE_HOST_DEVICE constexpr + size_type size() const + { + return 0; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const + { + return 0; + } + + CUTE_HOST_DEVICE constexpr + void fill(const T& value) + {} + + CUTE_HOST_DEVICE constexpr + void clear() + {} + + CUTE_HOST_DEVICE constexpr + void swap(array& other) + {} +}; + +template +CUTE_HOST_DEVICE constexpr +bool operator==(array const& lhs, array const& rhs) +{ + for (size_t i = 0; i < N; ++i) { + if (lhs[i] != rhs[i]) { + return false; + } + } + return true; +} + +template +CUTE_HOST_DEVICE constexpr +void clear(array& a) +{ + a.fill(T(0)); +} + +template +CUTE_HOST_DEVICE constexpr +void fill(array& a, T const& value) +{ + a.fill(value); +} + +template +CUTE_HOST_DEVICE constexpr +void swap(array& a, array& b) +{ + a.swap(b); +} + +/// @return A cute::array of the elements of @c t in reverse order. +template +CUTE_HOST_DEVICE constexpr +cute::array reverse(cute::array const& t) +{ + if constexpr (N == 0u) { + return t; + } else { + cute::array t_r{}; + for (size_t k = 0; k < N; ++k) { + t_r[k] = t[N - k - 1]; + } + return t_r; + } +} + +} // end cute + + +// +// Specialize tuple-related functionality for cute::array +// + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +T& get(array& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T const& get(array const& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T&& get(array&& a) +{ + static_assert(I < N, "Index out of range"); + return cute::move(a[I]); +} + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +template +struct tuple_size const> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element const> +{ + using type = T; +}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +template +struct tuple_size const> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element const> +{ + using type = T; +}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git 
a/server/punica_kernels/include/cutlass/cute/container/array_aligned.hpp b/server/punica_kernels/include/cutlass/cute/container/array_aligned.hpp new file mode 100644 index 00000000..9895a8da --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/container/array_aligned.hpp @@ -0,0 +1,42 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +namespace cute +{ + +template +struct CUTE_ALIGNAS(Alignment) array_aligned : cute::array {}; + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/container/array_subbyte.hpp b/server/punica_kernels/include/cutlass/cute/container/array_subbyte.hpp new file mode 100644 index 00000000..3ab3bc32 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/container/array_subbyte.hpp @@ -0,0 +1,634 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Statically sized array of elements that accommodates subbyte trivial types + in a packed storage. +*/ + +#pragma once + +#include + +#include +#include + +namespace cute +{ +// +// Underlying subbyte storage type +// +template +using subbyte_storage_type_t = conditional_t<(cute::sizeof_bits_v <= 8), uint8_t, + conditional_t<(cute::sizeof_bits_v <= 16), uint16_t, + conditional_t<(cute::sizeof_bits_v <= 32), uint32_t, + conditional_t<(cute::sizeof_bits_v <= 64), uint64_t, + conditional_t<(cute::sizeof_bits_v <= 128), uint128_t, + T>>>>>; + +template struct subbyte_iterator; +template struct swizzle_ptr; + +// +// subbyte_reference +// Proxy object for sub-byte element references +// +template +struct subbyte_reference +{ + // Iterator Element type (const or non-const) + using element_type = T; + // Iterator Value type without type qualifier. + using value_type = remove_cv_t; + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); + + static_assert(sizeof_bits_v <= sizeof_bits_v, + "Size of Element must not be greater than Storage."); + +private: + + // Bitmask for covering one item + static constexpr storage_type BitMask = storage_type(storage_type(-1) >> (sizeof_bits_v - sizeof_bits_v)); + // Flag for fast branching on straddled elements + static constexpr bool is_storage_unaligned = ((sizeof_bits_v % sizeof_bits_v) != 0); + + friend struct subbyte_iterator; + + // Pointer to storage element + storage_type* ptr_ = nullptr; + + // Bit index of value_type starting position within storage_type element. 
+ // RI: 0 <= idx_ < sizeof_bit + uint8_t idx_ = 0; + + // Ctor + template + CUTE_HOST_DEVICE constexpr + subbyte_reference(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast(ptr)), idx_(idx) {} + +public: + + // Copy Ctor + CUTE_HOST_DEVICE constexpr + subbyte_reference(subbyte_reference const& other) { + *this = element_type(other); + } + + // Copy Assignment + CUTE_HOST_DEVICE constexpr + subbyte_reference& operator=(subbyte_reference const& other) { + return *this = element_type(other); + } + + // Assignment + template + CUTE_HOST_DEVICE constexpr + enable_if_t, subbyte_reference&> operator=(element_type x) + { + static_assert(is_same_v, "Do not specify template arguments!"); + storage_type item = (reinterpret_cast(x) & BitMask); + + // Update the current storage element + storage_type bit_mask_0 = storage_type(BitMask << idx_); + ptr_[0] = storage_type((ptr_[0] & ~bit_mask_0) | (item << idx_)); + + // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic) + if (is_storage_unaligned && idx_ + sizeof_bits_v > sizeof_bits_v) { + uint8_t straddle_bits = uint8_t(sizeof_bits_v - idx_); + storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); + // Update the next storage element + ptr_[1] = storage_type((ptr_[1] & ~bit_mask_1) | (item >> straddle_bits)); + } + + return *this; + } + + // Comparison of referenced values + CUTE_HOST_DEVICE constexpr friend + bool operator==(subbyte_reference const& x, subbyte_reference const& y) { return x.get() == y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator!=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() != y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator< (subbyte_reference const& x, subbyte_reference const& y) { return x.get() < y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator> (subbyte_reference const& x, subbyte_reference const& y) { return x.get() > y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator<=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() <= y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator>=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() >= y.get(); } + + // Value + CUTE_HOST_DEVICE + element_type get() const + { + if constexpr (is_same_v) { // Extract to bool -- potentially faster impl + return bool((*ptr_) & (BitMask << idx_)); + } else { // Extract to element_type + // Extract from the current storage element + auto item = storage_type((ptr_[0] >> idx_) & BitMask); + + // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic) + if (is_storage_unaligned && idx_ + sizeof_bits_v > sizeof_bits_v) { + uint8_t straddle_bits = uint8_t(sizeof_bits_v - idx_); + storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); + // Extract from the next storage element + item |= storage_type((ptr_[1] & bit_mask_1) << straddle_bits); + } + + return reinterpret_cast(item); + } + } + + // Extract to type element_type + CUTE_HOST_DEVICE constexpr + operator element_type() const { + return get(); + } + + // Address + subbyte_iterator operator&() const { + return {ptr_, idx_}; + } +}; + +// +// subbyte_iterator +// Random-access iterator over subbyte references +// +template +struct subbyte_iterator +{ + // Iterator Element type (const or non-const) + using element_type = T; + // Iterator Value type without type qualifier. 
+ using value_type = remove_cv_t; + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + // Reference proxy type + using reference = subbyte_reference; + + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); + + static_assert(sizeof_bits_v <= sizeof_bits_v, + "Size of Element must not be greater than Storage."); + +private: + + template friend struct swizzle_ptr; + + // Pointer to storage element + storage_type* ptr_ = nullptr; + + // Bit index of value_type starting position within storage_type element. + // RI: 0 <= idx_ < sizeof_bit + uint8_t idx_ = 0; + +public: + + // Ctor + subbyte_iterator() = default; + + // Ctor + template + CUTE_HOST_DEVICE constexpr + subbyte_iterator(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast(ptr)), idx_(idx) { } + + CUTE_HOST_DEVICE constexpr + reference operator*() const { + return reference(ptr_, idx_); + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator+=(uint64_t k) { + k = sizeof_bits_v * k + idx_; + ptr_ += k / sizeof_bits_v; + idx_ = k % sizeof_bits_v; + return *this; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator+(uint64_t k) const { + return subbyte_iterator(ptr_, idx_) += k; + } + + CUTE_HOST_DEVICE constexpr + reference operator[](uint64_t k) const { + return *(*this + k); + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator++() { + idx_ += sizeof_bits_v; + if (idx_ >= sizeof_bits_v) { + ++ptr_; + idx_ -= sizeof_bits_v; + } + return *this; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator++(int) { + subbyte_iterator ret(*this); + ++(*this); + return ret; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator--() { + if (idx_ >= sizeof_bits_v) { + idx_ -= sizeof_bits_v; + } else { + --ptr_; + idx_ += sizeof_bits_v - sizeof_bits_v; + } + return *this; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator--(int) { + subbyte_iterator ret(*this); + --(*this); + return ret; + } + + CUTE_HOST_DEVICE constexpr friend + bool operator==(subbyte_iterator const& x, subbyte_iterator const& y) { + return x.ptr_ == y.ptr_ && x.idx_ == y.idx_; + } + CUTE_HOST_DEVICE constexpr friend + bool operator< (subbyte_iterator const& x, subbyte_iterator const& y) { + return x.ptr_ < y.ptr_ || (x.ptr_ == y.ptr_ && x.idx_ < y.idx_); + } + CUTE_HOST_DEVICE constexpr friend + bool operator!=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x == y); } + CUTE_HOST_DEVICE constexpr friend + bool operator<=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(y < x); } + CUTE_HOST_DEVICE constexpr friend + bool operator> (subbyte_iterator const& x, subbyte_iterator const& y) { return (y < x); } + CUTE_HOST_DEVICE constexpr friend + bool operator>=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x < y); } + + // Conversion to raw pointer with loss of subbyte index + CUTE_HOST_DEVICE constexpr friend + T* raw_pointer_cast(subbyte_iterator const& x) { + assert(x.idx_ == 0); + return reinterpret_cast(x.ptr_); + } + + // Conversion to NewT_ with possible loss of subbyte index + template + CUTE_HOST_DEVICE constexpr friend + auto recast_ptr(subbyte_iterator const& x) { + using NewT = conditional_t<(is_const_v), NewT_ const, NewT_>; + if constexpr (cute::is_subbyte_v) { // Making subbyte_iter, preserve the subbyte idx + return subbyte_iterator(x.ptr_, x.idx_); + } else { // Not subbyte, assume/assert subbyte idx 0 + return 
reinterpret_cast(raw_pointer_cast(x)); + } + CUTE_GCC_UNREACHABLE; + } + + CUTE_HOST_DEVICE friend void print(subbyte_iterator x) { + printf("subptr[%db](%p.%u)", int(sizeof_bits_v), x.ptr_, x.idx_); + } +}; + +// +// array_subbyte +// Statically sized array for non-byte-aligned data types +// +template +struct array_subbyte +{ + using element_type = T; + using value_type = remove_cv_t; + using pointer = element_type*; + using const_pointer = element_type const*; + + using size_type = size_t; + using difference_type = ptrdiff_t; + + // + // References + // + using reference = subbyte_reference; + using const_reference = subbyte_reference; + + // + // Iterators + // + using iterator = subbyte_iterator; + using const_iterator = subbyte_iterator; + + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); + +private: + + // Number of storage elements, ceil_div + static constexpr size_type StorageElements = (N * sizeof_bits_v + sizeof_bits_v - 1) / sizeof_bits_v; + + // Internal storage + storage_type storage[StorageElements]; + +public: + + constexpr + array_subbyte() = default; + + CUTE_HOST_DEVICE constexpr + array_subbyte(array_subbyte const& x) { + CUTE_UNROLL + for (size_type i = 0; i < StorageElements; ++i) { + storage[i] = x.storage[i]; + } + } + + CUTE_HOST_DEVICE constexpr + size_type size() const { + return N; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const { + return N; + } + + CUTE_HOST_DEVICE constexpr + bool empty() const { + return !N; + } + + // Efficient clear method + CUTE_HOST_DEVICE constexpr + void clear() { + CUTE_UNROLL + for (size_type i = 0; i < StorageElements; ++i) { + storage[i] = storage_type(0); + } + } + + CUTE_HOST_DEVICE constexpr + void fill(T const& value) { + CUTE_UNROLL + for (size_type i = 0; i < N; ++i) { + at(i) = value; + } + } + + CUTE_HOST_DEVICE constexpr + reference at(size_type pos) { + return iterator(storage)[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference at(size_type pos) const { + return const_iterator(storage)[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) { + return at(pos); + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const { + return at(pos); + } + + CUTE_HOST_DEVICE constexpr + reference front() { + return at(0); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const { + return at(0); + } + + CUTE_HOST_DEVICE constexpr + reference back() { + return at(N-1); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const { + return at(N-1); + } + + CUTE_HOST_DEVICE constexpr + pointer data() { + return reinterpret_cast(storage); + } + + CUTE_HOST_DEVICE constexpr + const_pointer data() const { + return reinterpret_cast(storage); + } + + CUTE_HOST_DEVICE constexpr + storage_type* raw_data() { + return storage; + } + + CUTE_HOST_DEVICE constexpr + storage_type const* raw_data() const { + return storage; + } + + CUTE_HOST_DEVICE constexpr + iterator begin() { + return iterator(storage); + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const { + return const_iterator(storage); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + iterator end() { + return iterator(storage) + N; + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const { + return const_iterator(storage) + N; + } + + 
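+  // Illustrative note (not in the upstream header): element access goes through
+  // the subbyte_reference / subbyte_iterator proxy types above, much like
+  // std::vector<bool>. operator[] returns a proxy by value, so prefer copying
+  // the value out (T v = a[i];) or assigning through the proxy (a[i] = v;)
+  // rather than trying to bind a plain T& to an element.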
CUTE_HOST_DEVICE constexpr + const_iterator cend() const { + return end(); + } + + // + // Comparison operators + // + +}; + +// +// Operators +// + +template +CUTE_HOST_DEVICE constexpr +void clear(array_subbyte& a) +{ + a.clear(); +} + +template +CUTE_HOST_DEVICE constexpr +void fill(array_subbyte& a, T const& value) +{ + a.fill(value); +} + +} // namespace cute + +// +// Specialize tuple-related functionality for cute::array_subbyte +// + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +T& get(array_subbyte& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T const& get(array_subbyte const& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T&& get(array_subbyte&& a) +{ + static_assert(I < N, "Index out of range"); + return cute::move(a[I]); +} + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct is_reference> + : CUTE_STL_NAMESPACE::true_type +{}; + + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/server/punica_kernels/include/cutlass/cute/container/bit_field.hpp b/server/punica_kernels/include/cutlass/cute/container/bit_field.hpp new file mode 100644 index 00000000..c5748d84 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/container/bit_field.hpp @@ -0,0 +1,133 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Portable bit field that supports byte and word straddling that can + be used in unions to bit-wise define parameters. +*/ + +#pragma once + +#include + +#include // uint_bit_t + +namespace cute +{ + +class dummy_type {}; + +template +struct bit_field +{ + static_assert(0 < NumBits && NumBits <= 64, "bit_fields with more than 64 bits are not supported."); + + // value_type: Use the smallest value type that fits NumBits + static constexpr uint32_t value_type_bits = (NumBits <= 8) ? 8 : + (NumBits <= 16) ? 16 : + (NumBits <= 32) ? 32 : 64; + using value_type = cute::uint_bit_t; + // storage_type: Use the smallest storage_type that avoids boundary crossing + static constexpr uint32_t storage_type_bits = (BitStart / 8 == (BitStart + NumBits - 1) / 8) ? 8 : + (BitStart / 16 == (BitStart + NumBits - 1) / 16) ? 16 : + (BitStart / 32 == (BitStart + NumBits - 1) / 32) ? 32 : 64; + using storage_type = cute::uint_bit_t; + + static_assert(sizeof(OtherValueType) == sizeof(value_type) || is_same::value, + "sizeof(OtherValueType) must be same as sizeof(value_type)."); + + // Number of storage values needed: ceil_div(BitStart + NumBits, storage_type_bits) + static constexpr uint32_t N = (BitStart + NumBits + storage_type_bits - 1) / storage_type_bits; + // Index of storage value for BitStart + static constexpr uint32_t idx = BitStart / storage_type_bits; + // Bit of data_[idx] for BitStart + static constexpr uint32_t bit_lo = BitStart % storage_type_bits; + // Number of bits in data_[idx] used for NumBits if straddling, else 0 + static constexpr uint32_t bit_hi = (idx + 1 < N) ? (storage_type_bits - bit_lo) : 0; + +public: + + // NumBits mask + static constexpr value_type mask = value_type(uint64_t(-1) >> (64u - NumBits)); + // NumBits mask for BitStart + static constexpr storage_type mask_lo = storage_type(mask) << bit_lo; + // NumBits mask for leftover bits in data_[idx+1] if straddling, else 0 + static constexpr storage_type mask_hi = (idx + 1 < N) ? 
(storage_type(mask) >> bit_hi) : 0; + + storage_type data_[N]; + + // Get value + CUTE_HOST_DEVICE constexpr + value_type get() const { + storage_type result = (data_[idx] & mask_lo) >> bit_lo; + if constexpr (bit_hi != 0) { + result |= (data_[idx+1] & mask_hi) << bit_hi; + } + return static_cast(result); + } + + // Set value + CUTE_HOST_DEVICE constexpr + void set(value_type x) { + storage_type item = static_cast(x & mask); + data_[idx] = static_cast((data_[idx] & ~mask_lo) | (item << bit_lo)); + if constexpr (bit_hi != 0) { + data_[idx+1] = static_cast((data_[idx+1] & ~mask_hi) | (item >> bit_hi)); + } + } + + // Assign value + CUTE_HOST_DEVICE constexpr + bit_field& operator=(value_type x) { + set(x); + return *this; + } + + // Cast to value + CUTE_HOST_DEVICE constexpr + operator value_type () const { + return get(); + } + + // Assign OtherValueType + CUTE_HOST_DEVICE constexpr + bit_field& operator=(OtherValueType x) { + return *this = *reinterpret_cast(&x); + } + + // Cast to OtherValueType + CUTE_HOST_DEVICE constexpr + operator OtherValueType () const { + value_type x = get(); + return *reinterpret_cast(&x); + } +}; + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/container/cuda_types.hpp b/server/punica_kernels/include/cutlass/cute/container/cuda_types.hpp new file mode 100644 index 00000000..8034cb27 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/container/cuda_types.hpp @@ -0,0 +1,187 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include +#include + +namespace cute +{ + +// +// dim3 +// + +using dim3 = ::dim3; + +// MSVC doesn't define its C++ version macro to match +// its C++ language version. 
This means that when +// building with MSVC, dim3 isn't constexpr-friendly. +template +CUTE_HOST_DEVICE +#if ! defined(_MSC_VER) +constexpr +#endif +uint32_t& get(dim3& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE +#if ! defined(_MSC_VER) +constexpr +#endif +uint32_t const& get(dim3 const& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE +#if ! defined(_MSC_VER) +constexpr +#endif +uint32_t&& get(dim3&& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return cute::move(a.x); + } else if constexpr (I == 1) { + return cute::move(a.y); + } else if constexpr (I == 2) { + return cute::move(a.z); + } + + CUTE_GCC_UNREACHABLE; +} + +// Specialize cute::tuple-traits for external types +template <> +struct tuple_size + : integral_constant +{}; + +template +struct tuple_element +{ + using type = uint32_t; +}; + +// +// uint3 +// + +using uint3 = ::uint3; + +template +CUTE_HOST_DEVICE constexpr +uint32_t& get(uint3& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t const& get(uint3 const& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t&& get(uint3&& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return cute::move(a.x); + } else if constexpr (I == 1) { + return cute::move(a.y); + } else if constexpr (I == 2) { + return cute::move(a.z); + } + + CUTE_GCC_UNREACHABLE; +} + +// Specialize cute::tuple-traits for external types +template <> +struct tuple_size + : integral_constant +{}; + +template +struct tuple_element +{ + using type = uint32_t; +}; + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/container/tuple.hpp b/server/punica_kernels/include/cutlass/cute/container/tuple.hpp new file mode 100644 index 00000000..0af98f56 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/container/tuple.hpp @@ -0,0 +1,720 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include // cute::true_type, cute::false_type +#include + +#include + +//#include // Advanced optimizations + +// +// cute::tuple is like std::tuple, with two differences. +// +// 1. It works on both host and device. +// 2. Its template arguments must be semiregular types. +// +// Semiregular types are default constructible and copyable. +// They include "value types" like int or float, +// but do _not_ include references like int& or float&. +// (See std::tie for an example of a tuple of references.) +// +// This is simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of +// the conversion SFINAE, special overloading, and avoiding cvref template types. +// Furthermore, the empty base optimization (EBO) is MORE aggressive by avoiding +// construction calls, and ignoring any need for unique element addresses. +// +// Over standard-conforming tuple implementations, this appears to accelerate compilation times by over 3x. + +namespace cute +{ + +namespace detail +{ + +// EBO stands for "empty base optimization." +// We use this technique to ensure that cute::tuple +// doesn't need to waste space storing any template arguments +// of cute::tuple that have no data (like integral_constant). +// Otherwise, cute::tuple would need to spend at least 1 byte +// for each of its template arguments. +// +// EBO always "holds" a single value of type T. +// N is like an array index that TupleBase uses +// to access the desired tuple element. +template ::value> +struct EBO; + +template +CUTE_HOST_DEVICE constexpr C findt(EBO const&) +{ return {}; } + +// Specialization for types T that have no data; +// the "static tuple leaf." Valid T here include +// integral_constant, Int, +// and any other semiregular type +// for which std::is_empty_v is true. +template +struct EBO +{ + CUTE_HOST_DEVICE constexpr + EBO() {} + + CUTE_HOST_DEVICE constexpr + EBO(T const&) {} +}; + +template +CUTE_HOST_DEVICE constexpr T getv(EBO const&) +{ return {}; } + +// Specialization for types T that are not empty; +// the "dynamic tuple leaf." Valid T here include int, +// any other integral or floating-point type, +// or any semiregular type for which std::is_empty_v is false. 
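+// Illustrative note (not in the upstream header): only these non-empty leaves
+// occupy storage; the empty leaves above are recovered by value from getv()
+// returning a default-constructed T. As a consequence, a tuple mixing static
+// and dynamic integers is expected to be no larger than its dynamic part on
+// typical ABIs, e.g. (sketch, assumption):
+//   static_assert(sizeof(cute::tuple<cute::C<1>, int, cute::C<8>>) == sizeof(int));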
+template +struct EBO +{ + CUTE_HOST_DEVICE constexpr + EBO() : t_{} {} + + template + CUTE_HOST_DEVICE constexpr + EBO(U const& u) : t_{u} {} + + T t_; +}; + +template +CUTE_HOST_DEVICE constexpr T const& getv(EBO const& x) +{ return x.t_; } + +template +CUTE_HOST_DEVICE constexpr T& getv(EBO& x) +{ return x.t_; } + +template +CUTE_HOST_DEVICE constexpr T&& getv(EBO&& x) +{ return cute::move(x.t_); } + +template +struct TupleBase; + +// Base class of cute::tuple binds each element to an index +// by inheriting from EBO for each (i, t) in (I..., T...). +// The storage (for nonempty t) lives in the base classes. +template +struct TupleBase, T...> + : EBO... +{ + CUTE_HOST_DEVICE constexpr + TupleBase() {} + + template + CUTE_HOST_DEVICE constexpr explicit + TupleBase(U const&... u) + : EBO(u)... {} + + template + CUTE_HOST_DEVICE constexpr + TupleBase(TupleBase, U...> const& u) + : EBO(getv(static_cast const&>(u)))... {} +}; + +} // end namespace detail + +// Attempting to use the following commented-out alias +// in the declaration of `struct tuple` causes MSVC 2022 build errors. +// +//template +//using TupleBase = detail::TupleBase, T...>; + +// This is the actual cute::tuple class. +// The storage (if any) lives in TupleBase's EBO base classes. +// +// Inheriting from the above alias TupleBase +// causes MSVC 2022 build errors when assigning one tuple to another: +// In summary: this is verbose as a work-around for MSVC build errors. +template +struct tuple : detail::TupleBase, T...> +{ + CUTE_HOST_DEVICE constexpr + tuple() {} + + template + CUTE_HOST_DEVICE constexpr + tuple(U const&... u) : detail::TupleBase, T...>(u...) {} + + template + CUTE_HOST_DEVICE constexpr + tuple(tuple const& u) + : detail::TupleBase, T...>(static_cast, U...> const&>(u)) {} +}; + +// +// get for cute::tuple (just like std::get for std::tuple) +// + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple const& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple&& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(static_cast&&>(t)); +} + +// +// find a type X within a cute::tuple +// Requires X to be unique in tuple +// Returns a static integer +// + +template +CUTE_HOST_DEVICE constexpr +auto +find(tuple const& t) noexcept +{ + return detail::findt(t); +} + +// +// Custom is_tuple trait simply checks the existence of tuple_size +// and assumes std::get(.), std::tuple_element +// +namespace detail { + +template +auto has_tuple_size( T*) -> bool_constant<(0 <= tuple_size::value)>; +auto has_tuple_size(...) -> false_type; + +} // end namespace detail + +template +struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {}; + +// +// make_tuple (value-based implementation) +// + +template +CUTE_HOST_DEVICE constexpr +tuple +make_tuple(T const&... t) +{ + return {t...}; +} + +// +// tuple_cat concatenates multiple cute::tuple into a single cute::tuple, +// just like std::tuple_cat for std::tuple. 
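+// Minimal usage sketch (illustrative, not in the upstream header):
+//   auto a = cute::make_tuple(1, cute::_2{});
+//   auto b = cute::make_tuple(3);
+//   auto c = cute::tuple_cat(a, b);   // roughly cute::tuple<int, cute::C<2>, int> holding (1, _2, 3)
+//   int  x = cute::get<2>(c);         // 3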
+// + +#if 0 +// Original implementation + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + return detail::tuple_cat(t0, t1, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, Ts const&... ts) +{ + return cute::tuple_cat(cute::tuple_cat(t0,t1),t2,ts...); +} +#endif + +#if 1 +// Extended implementation + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, + index_sequence, index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, + index_sequence, index_sequence, index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, + index_sequence, index_sequence, index_sequence, index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)..., get(t4)...); +} + +template +struct tuple_cat_static; + +template +struct tuple_cat_static, tuple> { + using type = tuple; +}; + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + if constexpr (is_static::value && is_static::value && + is_tuple::value && is_tuple::value) { + return typename detail::tuple_cat_static::type{}; + } else { + return detail::tuple_cat(t0, t1, + make_index_sequence::value>{}, + make_index_sequence::value>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2) +{ + return detail::tuple_cat(t0, t1, t2, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3) +{ + return detail::tuple_cat(t0, t1, t2, t3, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4) +{ + return detail::tuple_cat(t0, t1, t2, t3, t4, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& 
t4, T5 const& t5, Ts const&... ts) +{ + return cute::tuple_cat(cute::tuple_cat(t0,t1,t2,t3,t4), cute::tuple_cat(t5, ts...)); +} +#endif + +#if 0 +// Outer-Inner indexing trick to concat all tuples at once + +namespace detail { + +template +struct tuple_cat_helper +{ + static constexpr cute::array ns = {Ns...}; + + static constexpr size_t total_size() { + size_t sum = 0; + for (size_t n : ns) sum += n; + return sum; + } + static constexpr size_t total_size_ = total_size(); + + static constexpr auto values() { + cute::array outer_inner = {}; + + size_t idx = 0; + for (size_t i = 0; i < ns.size(); ++i) { + for (size_t j = 0; j < ns[i]; ++j, ++idx) { + outer_inner[idx][0] = i; + outer_inner[idx][1] = j; + } + } + return outer_inner; + } + static constexpr auto outer_inner_ = values(); + + using total_sequence = make_index_sequence; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(Tuple const& t, index_sequence) +{ + return cute::make_tuple(get(get(t))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + return detail::tuple_cat(t0, t1, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(Tuples const&... ts) +{ + using Helper = detail::tuple_cat_helper::value...>; + return detail::tuple_cat(cute::make_tuple(ts...), typename Helper::total_sequence{}); +} +#endif + +// +// Equality operators +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +equal_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted + } else { + return (get(a) == get(b)) && equal_impl(a,b); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template ::value && is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator==(TupleT const& t, TupleU const& u) +{ + return detail::equal_impl<0>(t, u); +} + +template ::value ^ is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator==(TupleT const& t, TupleU const& u) +{ + return cute::false_type{}; +} + +template ::value && is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator!=(TupleT const& t, TupleU const& u) +{ + return !(t == u); +} + +template ::value ^ is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator!=(TupleT const& t, TupleU const& u) +{ + return cute::true_type{}; +} + +// +// Comparison operators +// + +// +// There are many ways to compare tuple of elements and because CuTe is built +// on parameterizing layouts of coordinates, some comparisons are appropriate +// only in certain cases. +// -- lexicographical comparison [reverse, reflected, revref] +// -- colexicographical comparison [reverse, reflected, revref] +// -- element-wise comparison [any,all] +// This can be very confusing. To avoid errors in selecting the appropriate +// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. +// +// That said, see int_tuple for more explicitly named common comparison ops. 
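+// Illustrative note (not in the upstream header): with the equality operators
+// above, comparing all-static tuples is expected to fold to a compile-time
+// result, while any dynamic element yields a runtime bool, e.g. (sketch):
+//   auto eq = (cute::make_tuple(cute::_2{}, cute::_4{}) ==
+//              cute::make_tuple(cute::_2{}, cute::_4{}));        // static true-like type
+//   bool ne = (cute::make_tuple(2, 4) != cute::make_tuple(2, 5)); // runtime, evaluates to true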
+// + +// +// Display utilities +// + +namespace detail { + +template +CUTE_HOST_DEVICE void print_tuple(Tuple const& t, + index_sequence, char s = '(', char e = ')') +{ + using cute::print; + ((void(print(Is == 0 ? s : ',')), void(print(get(t)))), ...); print(e); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, + index_sequence, char s = '(', char e = ')') +{ + (void(os << (Is == 0 ? s : ',') << get(t)), ...); + return os << e; +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace detail + +template ::value)> +CUTE_HOST_DEVICE void print(Tuple const& t) +{ + return detail::print_tuple(t, make_index_sequence::value>{}); +} + +#if !defined(__CUDACC_RTC__) +template ::value)> +CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t) +{ + return detail::print_tuple_os(os, t, make_index_sequence::value>{}); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace CUTE_STL_NAMESPACE + +// +// std compatibility +// + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namepsace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/server/punica_kernels/include/cutlass/cute/container/type_list.hpp b/server/punica_kernels/include/cutlass/cute/container/type_list.hpp new file mode 100644 index 00000000..41c499ec --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/container/type_list.hpp @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +namespace cute +{ + +template +struct type_c { + using type = T; +}; + +template +struct type_list {}; + +} // end namespace cute + +// +// Specialize tuple-related functionality for cute::type_list +// + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +CUTE_STL_NAMESPACE::tuple_element_t> +get(type_list&) noexcept { + return {}; +} +template +CUTE_HOST_DEVICE constexpr +CUTE_STL_NAMESPACE::tuple_element_t> +get(type_list const& t) noexcept { + return {}; +} + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : cute::type_c>::type> +{}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : cute::type_c>::type> +{}; + +} // end namespace std + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : cute::type_c>::type> +{}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : cute::type_c>::type> +{}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/server/punica_kernels/include/cutlass/cute/int_tuple.hpp b/server/punica_kernels/include/cutlass/cute/int_tuple.hpp new file mode 100644 index 00000000..110e233a --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/int_tuple.hpp @@ -0,0 +1,946 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include + +/** IntTuple is an integer or a tuple of IntTuples. + * This file holds utilities for working with IntTuples, + * but does not hold a concrete concept or class of IntTuple. + */ + +namespace cute +{ + +// Implementation of get<0>(Integral). +// Even though is_tuple is false and tuple_size doesn't compile, +// CuTe defines rank(Integral) as 1, so it's useful for get<0>(Integral) to return its input +template >::value)> +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(T&& t) noexcept +{ + static_assert(I == 0, "Index out of range"); + return static_cast(t); +} + +// Custom recursive get for anything that implements get(.) (for a single integer I). +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(T&& t) noexcept +{ + return get(get(static_cast(t))); +} + +// +// rank +// + +template +CUTE_HOST_DEVICE constexpr +auto +rank(IntTuple const& t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int::value>{}; + } else { + return Int<1>{}; + } + } else { + return rank(get(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using rank_t = decltype(rank(declval())); + +template +static constexpr int rank_v = rank_t::value; + +// +// shape +// + +template +CUTE_HOST_DEVICE constexpr +auto +shape(IntTuple const& s) +{ + if constexpr (is_tuple::value) { + return transform(s, [](auto const& a) { return shape(a); }); + } else { + return s; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +shape(IntTuple const& s) +{ + if constexpr (is_tuple::value) { + return shape(get(s)); + } else { + return get(shape(s)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// max +// + +template +CUTE_HOST_DEVICE constexpr +auto +max(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::max(cute::apply(t0, [](auto const&... a){ return cute::max(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::max(t0, cute::max(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// min +// + +template +CUTE_HOST_DEVICE constexpr +auto +min(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::min(cute::apply(t0, [](auto const&... a){ return cute::min(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::min(t0, cute::min(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// gcd +// + +template +CUTE_HOST_DEVICE constexpr +auto +gcd(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::gcd(cute::apply(t0, [](auto const&... 
a){ return cute::gcd(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::gcd(t0, cute::gcd(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// depth +// + +template +CUTE_HOST_DEVICE constexpr +auto +depth(IntTuple const& t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int<1>{} + cute::apply(t, [](auto const&... v){ return cute::max(depth(v)...); }); + } else { + return Int<0>{}; + } + } else { + return depth(get(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using depth_t = decltype(depth(declval())); + +template +static constexpr int depth_v = depth_t::value; + +// +// product +// + +// Implementation of product as a function object +struct Product +{ + template + CUTE_HOST_DEVICE constexpr + auto + operator()(IntTuple const& a) const + { + if constexpr (is_tuple::value) { + if constexpr (tuple_size::value == 0) { + return Int<1>{}; + } else { + return cute::transform_apply(a, Product{}, multiplies_unary_lfold{}); + } + } else if constexpr (cute::is_integral::value) { + return a; + } + + CUTE_GCC_UNREACHABLE; + } +}; +// Callable product function object +CUTE_INLINE_CONSTANT Product product; + +// Return a rank(t) tuple @a result such that get(@a result) = product(get(@a t)) +template +CUTE_HOST_DEVICE constexpr +auto +product_each(Tuple const& t) +{ + return transform(wrap(t), product); +} + +// Take the product of Tuple at the leaves of TupleG +template +CUTE_HOST_DEVICE constexpr +auto +product_like(Tuple const& tuple, TupleG const& guide) +{ + return transform_leaf(guide, tuple, [](auto const& g, auto const& t) { return product(t); }); +} + +// Return the product of elements in a mode +template +CUTE_HOST_DEVICE constexpr +auto +size(IntTuple const& a) +{ + if constexpr (sizeof...(Is) == 0) { + return product(a); + } else { + return size(get(a)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +static constexpr int size_v = decltype(size(declval()))::value; + +// +// sum +// + +template +CUTE_HOST_DEVICE constexpr +auto +sum(IntTuple const& a) +{ + if constexpr (is_tuple::value) { + return cute::apply(a, [](auto const&... v){ return (Int<0>{} + ... + sum(v)); }); + } else { + return a; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// inner_product +// + +template +CUTE_HOST_DEVICE constexpr +auto +inner_product(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform_apply(a, b, [](auto const& x, auto const& y) { return inner_product(x,y); }, + [](auto const&... v) { return (Int<0>{} + ... + v); }); + } else { + return a * b; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// ceil_div +// + +template +CUTE_HOST_DEVICE constexpr +auto +ceil_div(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 + return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); + } else { + return (a + b - Int<1>{}) / b; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// round_up +// Round @a a up to the nearest multiple of @a b. +// For negative numbers, rounds away from zero. 
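A small standalone sketch (not part of the patch; the umbrella include and its path are assumed) of the IntTuple arithmetic defined above and the ceil_div/round_up behavior described here: products and sizes respect the hierarchical shape, while ceil_div and round_up act element-wise on tuples and as ordinary rounding on plain integers.

// Illustrative sketch only -- not part of this patch; include path is assumed.
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  auto shp = make_shape(Int<4>{}, make_shape(2, 3));   // hierarchical IntTuple (_4,(2,3))

  auto n   = size(shp);                                // product over all modes: 4*2*3 = 24
  auto n0  = size<0>(shp);                             // product of mode 0 only: _4 (stays static)
  auto dot = inner_product(make_shape(2, 3), make_shape(5, 7));   // 2*5 + 3*7 = 31

  auto q = ceil_div(make_shape(10, 8), make_shape(4, 3));  // element-wise: (3, 3)
  auto r = round_up(10, 4);                                // scalar case: next multiple of 4 -> 12

  return (n == 24 && n0 == 4 && dot == 31 &&
          get<0>(q) == 3 && get<1>(q) == 3 && r == 12) ? 0 : 1;
}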
+// + +template +CUTE_HOST_DEVICE constexpr +auto +round_up(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 + return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return round_up(x,y); }); + } else { + return ((a + b - Int<1>{}) / b) * b; + } + + CUTE_GCC_UNREACHABLE; +} + +/** Division for Shapes + * Case Tuple Tuple: + * Perform shape_div element-wise + * Case Tuple Int: + * Fold the division of b across each element of a + * Example: shape_div((4,5,6),40) -> shape_div((1,5,6),10) -> shape_div((1,1,6),2) -> (1,1,3) + * Case Int Tuple: + * Return shape_div(a, product(b)) + * Case Int Int: + * Enforce the divisibility condition a % b == 0 || b % a == 0 when possible + * Return a / b with rounding away from 0 (that is, 1 or -1 when a < b) + */ +template +CUTE_HOST_DEVICE constexpr +auto +shape_div(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); }); + } else { // tuple int + auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), + [] (auto const& init, auto const& ai) { + return cute::make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai)); + }); + return result; + } + } else + if constexpr (is_tuple::value) { // int tuple + return shape_div(a, product(b)); + } else + if constexpr (is_static::value && is_static::value) { + static_assert(IntTupleA::value % IntTupleB::value == 0 || IntTupleB::value % IntTupleA::value == 0, "Static shape_div failure"); + return C{}; + } else { // int int + //assert(a % b == 0 || b % a == 0); // Waive dynamic assertion + return a / b != 0 ? 
a / b : signum(a) * signum(b); // Division with rounding away from zero + } + + CUTE_GCC_UNREACHABLE; +} + +/** Minimum for Shapes + */ +template +CUTE_HOST_DEVICE constexpr +auto +shape_min(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value || is_tuple::value) { + static_assert(dependent_false, "Not implemented."); + } else + if constexpr (is_constant<1, IntTupleA>::value || is_constant<1, IntTupleB>::value) { + return Int<1>{}; // _1 is less than all other shapes, preserve static + } else { + return cute::min(a, b); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Return a tuple the same profile as A scaled by corresponding elements in B + */ +template +CUTE_HOST_DEVICE constexpr +auto +elem_scale(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + return transform(a, b, [](auto const& x, auto const& y) { return elem_scale(x,y); }); + } else { + return a * product(b); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Test if two IntTuple have the same profile (hierarchical rank division) + */ +template +CUTE_HOST_DEVICE constexpr +auto +congruent(IntTupleA const& a, IntTupleB const& b) +{ + return bool_constant::value>{}; +} + +template +using is_congruent = decltype(congruent(declval(), declval())); + +/** Test if two IntTuple have the similar profiles up to Shape A (hierarchical rank division) + * weakly_congruent is a partial order on A and B: A <= B + */ +template +CUTE_HOST_DEVICE constexpr +auto +weakly_congruent(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return weakly_congruent(x,y); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else if constexpr (is_integral::value) { + return true_type{}; + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return weakly_congruent(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_weakly_congruent = decltype(weakly_congruent(declval(), declval())); + +/** Test if Shape A is compatible with Shape B: + * the size of A and B are the same, and + * any coordinate into A can also be used as a coordinate into B + * compatible is a partial order on A and B: A <= B + */ +template +CUTE_HOST_DEVICE constexpr +auto +compatible(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return compatible(x,y); }, + [](auto const&... z) { return (true_type{} && ... 
&& z); }); + } + } else if constexpr (is_integral::value) { + return a == size(b); + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return compatible(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_compatible = decltype(compatible(declval(), declval())); + +/** Test if Shape A is weakly compatible with Shape B: + * there exists a Shape C congruent to A such that compatible(elem_scale(A,C), B) + * weakly_compatible is a partial order on A and B: A <= B + */ +template +CUTE_HOST_DEVICE constexpr +auto +weakly_compatible(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return weakly_compatible(x,y); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else if constexpr (is_integral::value) { + return size(b) % a == Int<0>{}; + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return weakly_compatible(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_weakly_compatible = decltype(weakly_compatible(declval(), declval())); + +/** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1> + */ +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + return transform(a, b, [](auto const& x, auto const& y) { return filter_zeros(x,y); }); + } else if constexpr (is_constant<0, IntTupleA>::value) { + return Int<1>{}; + } else { + return b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tuple const& t) +{ + return filter_zeros(t, t); +} + +// +// Converters and constructors with arrays and params +// + +/** Make an IntTuple of rank N from an Indexable array. + * Access elements up to a dynamic index n, then use init (requires compatible types) + * Consider cute::take if all indexing is known to be valid + * \code + * std::vector a = {6,3,4}; + * auto tup = make_int_tuple<5>(a, a.size(), 0) // (6,3,4,0,0) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_int_tuple(Indexable const& t, int n, T const& init) +{ + static_assert(N > 0); + if constexpr (N == 1) { + return 0 < n ? t[0] : init; + } else { + return transform(make_seq{}, [&](auto i) { return i < n ? 
t[i] : init; }); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Fill the dynamic values of a Tuple with values from another Tuple + * \code + * auto params = make_tuple(6,3,4); + * cute::tuple, cute::tuple>, int, Int<2>> result; + * fill_int_tuple_from(result, params); // (_1,(6,3,_3),4,_2) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +fill_int_tuple_from(Tuple& result, TupleV const& vals) +{ + return fold(result, vals, [](auto const& init, auto&& r) { + if constexpr (is_static>::value) { // Skip static elements of result + return init; + } else if constexpr (is_tuple>::value) { // Recurse into tuples + return fill_int_tuple_from(r, init); + } else { // Assign and consume arg + static_assert(tuple_size>::value > 0, "Not enough values to fill with!"); + r = get<0>(init); + return remove<0>(init); + } + + CUTE_GCC_UNREACHABLE; + }); +} + +/** Make a "Tuple" by filling in the dynamic values in order from the arguments + * \code + * using result_t = cute::tuple, cute::tuple>, int, Int<2>>; + * auto result = make_int_tuple_from(6,3,4); // (_1,(6,3,_3),4,_2) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +Tuple +make_int_tuple_from(Ts const&... ts) +{ + Tuple result = Tuple{}; + fill_int_tuple_from(result, cute::make_tuple(ts...)); + return result; +} + +/** Convert a tuple to a flat homogeneous array of type T + * \code + * auto tup = cute::make_tuple(Int<1>{}, cute::make_tuple(6,3,Int<3>{}),4,Int<2>{}); + * cute::array result = to_array(tup); // [1,6,3,3,4,2] + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +to_array(IntTuple const& t) +{ + auto flat_t = flatten_to_tuple(t); + constexpr int N = tuple_size::value; + cute::array result; + for_each(make_seq{}, [&] (auto i) { result[i] = get(flat_t); }); + return result; +} + +// +// Comparison operators +// + +// +// There are many ways to compare tuple of elements and because CuTe is built +// on parameterizing layouts of coordinates, some comparisons are appropriate +// only in certain cases. +// -- lexicographical comparison [reverse, reflected, revref] : Correct for coords in RowMajor Layout +// -- colexicographical comparison [reverse, reflected, revref] : Correct for coords in ColMajor Layout +// -- element-wise comparison [any,all] : +// This can be very confusing. To avoid errors in selecting the appropriate +// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. +// +// When actually desiring to order coordinates, the user should map them to +// their indices within the Layout they came from: +// e.g. layoutX(coordA) < layoutX(coordB) +// That said, we implement the three most common ways to compare tuples below. +// These are implemented with slighly more explicit names than op<. 
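The explicitly named comparisons declared next can be exercised with a standalone sketch like the one below (illustrative only, not part of the patch; include assumptions as before): lex_less orders by the first mode, colex_less by the last mode, and elem_less requires every mode to compare less, which is what makes it the natural coordinate-in-shape test.

// Illustrative sketch only -- not part of this patch; include path is assumed.
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  auto a = make_coord(1, 2);
  auto b = make_coord(2, 1);

  bool lex   = lex_less(a, b);    // true : mode 0 is compared first (1 < 2)
  bool colex = colex_less(a, b);  // false: the last mode is compared first (2 > 1)
  bool elem  = elem_less(a, b);   // false: requires a[i] < b[i] in every mode

  // elem_less against a shape answers "is this coordinate inside this shape?"
  bool in_bounds = elem_less(make_coord(1, 2), make_shape(4, 3));   // true

  return (lex && !colex && !elem && in_bounds) ? 0 : 1;
}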
+// + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less(IntTupleA const& a, IntTupleB const& b); + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less(IntTupleA const& a, IntTupleB const& b); + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less(IntTupleA const& a, IntTupleB const& b); + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleB is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted + } else { + return lex_less(get(a), get(b)) || (get(a) == get(b) && lex_less_impl(a,b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleB is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted + } else { + constexpr size_t A = tuple_size::value - 1 - I; + constexpr size_t B = tuple_size::value - 1 - I; + return colex_less(get(a), get(b)) || (get(a) == get(b) && colex_less_impl(a,b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted + } else { + return elem_less(get(a), get(b)) && elem_less_impl(a,b); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Lexicographical comparison + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::lex_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_leq(T const& t, U const& u) { + return !lex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_gtr(T const& t, U const& u) { + return lex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_geq(T const& t, U const& u) { + return !lex_less(t, u); +} + +// Colexicographical comparison + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::colex_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_leq(T const& t, U const& u) { + return !colex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_gtr(T const& t, U const& u) { + return colex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_geq(T const& t, U const& u) { + return !colex_less(t, u); +} + +// Elementwise [all] comparison + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::elem_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_leq(T const& t, U const& u) { + return !elem_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_gtr(T const& t, U const& u) { + 
return elem_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_geq(T const& t, U const& u) { + return !elem_less(t, u); +} + +namespace detail { + +/** Increment a (dynamic) coord lexicographically within a shape + * @pre is_congruent::value + * \code + * auto shape = make_shape(1,2,make_shape(2,3),3); + * + * int i = 0; + * for (auto coord = repeat_like(shape, 0); back(coord) != back(shape); increment(coord, shape)) { + * std::cout << i++ << ": " << coord << std::endl; + * } + * assert(i == size(shape)); + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +void +increment(Coord& coord, Shape const& shape) +{ + if constexpr (is_integral::value) { + ++coord; + } else { + increment(get(coord), get(shape)); + if constexpr (I+1 < tuple_size::value) { + if (back(get(coord)) == back(get(shape))) { + back(get(coord)) = 0; + increment(coord, shape); + } + } + } +} + +} // end namespace detail + +struct ForwardCoordIteratorSentinal +{}; + +// A forward iterator for a starting coordinate in a shape's domain, and a shape. +// The starting coordinate may be zero but need not necessarily be. +template +struct ForwardCoordIterator +{ + static_assert(is_congruent::value); + + CUTE_HOST_DEVICE constexpr + Coord const& operator*() const { return coord; } + + CUTE_HOST_DEVICE constexpr + ForwardCoordIterator& operator++() { detail::increment(coord, shape); return *this; } + + // Sentinel for the end of the implied range + CUTE_HOST_DEVICE constexpr + bool operator< (ForwardCoordIteratorSentinal const&) const { return back(coord) < back(shape); } + CUTE_HOST_DEVICE constexpr + bool operator==(ForwardCoordIteratorSentinal const&) const { return back(coord) == back(shape); } + CUTE_HOST_DEVICE constexpr + bool operator!=(ForwardCoordIteratorSentinal const&) const { return back(coord) != back(shape); } + // NOTE: These are expensive, avoid use + CUTE_HOST_DEVICE constexpr + bool operator< (ForwardCoordIterator const& other) const { return colex_less(coord, other.coord); } + CUTE_HOST_DEVICE constexpr + bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; } + CUTE_HOST_DEVICE constexpr + bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; } + + Coord coord; + Shape const& shape; +}; + +// A forward iterator for a coordinate that starts from a provided coordinate +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Coord const& coord, Shape const& shape) +{ + return ForwardCoordIterator{coord,shape}; +} + +// A forward iterator for a coordinate that starts from zero +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Shape const& shape) +{ + auto coord = repeat_like(shape, int(0)); + return make_coord_iterator(coord, shape); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/layout.hpp b/server/punica_kernels/include/cutlass/cute/layout.hpp new file mode 100644 index 00000000..71c4ce13 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/layout.hpp @@ -0,0 +1,1895 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace cute +{ + +// Aliases + +template +using Shape = cute::tuple; + +template +using Stride = cute::tuple; + +template +using Step = cute::tuple; + +template +using Coord = cute::tuple; + +template +using Tile = cute::tuple; + +template +CUTE_HOST_DEVICE constexpr +Shape +make_shape(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Stride +make_stride(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Step +make_step(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Coord +make_coord(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Tile +make_tile(Ts const&... t) +{ + return {t...}; +} + +// +// Layout +// + +template > +struct Layout + : private cute::tuple // EBO for static layouts +{ + // Expensive in compilation time... 
+ //static_assert(is_congruent::value, "Shape and Stride must be congruent"); + + // NOTE: This defaults static Shapes/Strides correctly, but not dynamic + CUTE_HOST_DEVICE constexpr + Layout(Shape const& shape = {}, Stride const& stride = {}) + : cute::tuple(shape, stride) + {} + + // + // Accessors + // + + static constexpr int rank = rank_v; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return *this; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() { + return get<0,I...>(static_cast&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return get<0,I...>(static_cast const&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() { + return get<1,I...>(static_cast&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const { + return get<1,I...>(static_cast const&>(*this)); + } + + // + // Mappings + // + + // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + // OR + // Slice the layout and return the sublayout (Coord has an Underscore slice op) + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return slice(coord, *this); + } else { + return crd2idx(coord, shape(), stride()); + } + + CUTE_GCC_UNREACHABLE; + } + + // Convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(OtherLayout const& other) const { + return composition(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return composition(*this, make_tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(OtherShape const& shape) const { + return composition(*this, make_layout(shape)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(Shapes const&... shapes) const { + return composition(*this, make_layout(make_shape(shapes...))); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(OtherLayout const& other) const { + return tiled_divide(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... 
layouts) const { + return tiled_divide(*this, make_tile(layouts...)); + } + + // + // Utility + // + + // + // Index to Coordinate + // + + // NOTE: Only valid for compact layouts + + // Return the (hierarchical) ND logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post congruent(@a result, shape()) + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_hier_coord(IInt const& idx) const { + return cute::idx2crd(idx, shape(), stride()); + } + + // Return the (flat) ND logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_flat_coord(IInt const& idx) const { + return cute::crd2crd(this->get_hier_coord(idx), shape(), repeat(Int<1>{})); + } + + // Return the generalized column-major 1D logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post is_integral::value + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_1d_coord(IInt const& idx) const { + return cute::crd2idx(this->get_hier_coord(idx), shape()); + } + + // + // Coordinate to Coordinate + // + +#if 0 + // Return the (hierarchical) ND logical coordinate corresponding to the linear index + // @post congruent(@a result, shape()) + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_hier_coord(Coord const& crd) const { + return cute::crd2crd(crd, shape(), shape()); + } + + // Return the (flat) ND logical coordinate corresponding to the linear index + // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_flat_coord(Coord const& crd) const { + return cute::crd2crd(crd, shape(), product_each(shape())); + } + + // Return the generalized column-major 1D logical coordinate corresponding to the linear index + // @post is_integral::value + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_1d_coord(Coord const& crd) const { + //return cute::crd2crd(crd, shape(), product(shape())); + return cute::crd2idx(crd, shape()); + } +#endif +}; + +// Equality, return a static or dynamic boolean +template +CUTE_HOST_DEVICE constexpr +auto +operator==(Layout const& layoutA, Layout const& layoutB) +{ + return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride(); +} + +template +struct is_layout : false_type {}; +template +struct is_layout> : true_type {}; + +// +// Layout construction +// + +template ::value || is_integral::value) && + (is_tuple::value || is_integral::value))> +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, Stride const& stride) +{ + return Layout(shape, stride); +} + +template ::value || is_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape) +{ + return make_layout(shape, compact_col_major(shape)); +} + +// Construct a layout from multiple layouts by +// concatenating each layout as an independent mode +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const&... 
layouts) +{ + return make_layout(make_shape (layouts.shape()...), + make_stride(layouts.stride()...)); +} + +// +// Convenience tags for common layouts +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, GenColMajor) +{ + return make_layout(shape, compact_col_major(shape)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, GenRowMajor) +{ + return make_layout(shape, compact_row_major(shape)); +} + +// +// Advanced Layout constructions +// + +// Make a compact layout with shape @a shape and strides following the order induced by @a order. +// Dynamic values in @a order are ignored, considered large, and considered ordered from left to right. +// Example: +// make_ordered_layout(Shape<_2,_2,_2,_2>{}, Step<_0,_2,_3,_1>{}) +// -> (_2,_2,_2,_2):(_1,_4,_8,_2) +// make_ordered_layout(make_shape(2,3,4,5), make_step(Int<2>{}, 67, 42, Int<50>{})) +// -> (2,3,4,5):(_1,10,30,2) +template +CUTE_HOST_DEVICE constexpr +auto +make_ordered_layout(Shape const& shape, Order const& order) +{ + return make_layout(shape, compact_order(shape, order)); +} + +// Make a compact layout with the same shape as @a layout +// and strides following the order induced by @a layout.stride(). +// Static-0 strides in the input @a layout are preserved in the output. +// Example: +// make_layout_like(Layout, Stride<_0,_2,_4,_1>>{}) +// -> (_2,_2,_2,_2):(_0,_2,_4,_1) +// make_layout_like(make_layout(make_shape(2,3,4,5), make_stride(Int<0>{},42,Int<1>{},Int<0>{}))) +// -> (2,3,4,5):(_0,4,_1,_0) +template +CUTE_HOST_DEVICE constexpr +auto +make_layout_like(Layout const& layout) +{ + return make_layout(layout.shape(), + compact_order(filter_zeros(layout.stride(), layout.shape()), layout.stride())); +} + +// Make a compact layout with the same shape as @a layout +// and strides following the order induced by @a layout.stride(), +// except mode-0 is always stride-1 and generated column-major. +// The 0th mode is commonly used for MMA_Atoms or Copy_Atoms +// so this generates the 0th mode with LayoutLeft regardless of the reference layout. +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Layout const& layout) +{ + constexpr int R = Layout::rank; + if constexpr (R > 1 && is_static::value) { + return tiled_product(make_layout(shape<0>(layout)), + make_ordered_layout(take<1,R>(layout.shape()), take<1,R>(layout.stride()))); + } else { + return make_layout(layout.shape()); + } + + CUTE_GCC_UNREACHABLE; +} + +template ::value || is_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Shape const& shape) +{ + return make_layout(shape); +} + +// +// Make an identity layout that maps a coordinate to itself +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_identity_layout(Shape const& shape) +{ + return make_layout(shape, make_basis_like(shape)); +} + +// +// Operations to manipulate Layouts like a tuple of pairs +// + +// Return the Is...th sublayout. +// For Is... = , equivalent to get(...get(get(layout))) +template +CUTE_HOST_DEVICE constexpr +auto +get(Layout const& layout) +{ + return make_layout(get(layout.shape()), + get(layout.stride())); +} + +// Return a new layout with only the modes in the range [B,E) +template +CUTE_HOST_DEVICE constexpr +auto +take(Layout const& layout) +{ + static_assert(B < E, "take: empty range error"); + static_assert(0 <= B && E <= Layout::rank, "take: range out of bounds"); + return make_layout(take(layout.shape()), + take(layout.stride())); +} + +// Return a new layout with only the modes Is... 
= +template +CUTE_HOST_DEVICE constexpr +auto +select(Layout const& layout) +{ + return make_layout(select(layout.shape()), + select(layout.stride())); +} + +// Return a layout with depth at most 1 +template +CUTE_HOST_DEVICE constexpr +auto +flatten(Layout const& layout) +{ + return make_layout(flatten(layout.shape()), + flatten(layout.stride())); +} + +// Return a layout whose profile is congruent to TargetProfile +// @pre Input layout is flat, flatten(@a layout) == @a layout +// @pre Input layout can be folded to profile, rank(@a layout) == rank(flatten(@a target_profile)) +// @post congruent(@a result, @a target_profile) +template +CUTE_HOST_DEVICE constexpr +auto +unflatten(Layout const& layout, TargetProfile const& target_profile) +{ + return make_layout(unflatten(layout.shape(), target_profile), + unflatten(layout.stride(), target_profile)); +} + +// +// Utilities +// + +// Return the sublayout of mode I... +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(Layout const& layout) +{ + if constexpr (sizeof...(Is) == 0) { + return layout; + } else { + return get(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(Layout& layout) +{ + return layout.template shape(); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(Layout const& layout) +{ + return layout.template shape(); +} + +// Return the stride of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(Layout& layout) +{ + return layout.template stride(); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(Layout const& layout) +{ + return layout.template stride(); +} + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +auto +size(Layout const& layout) +{ + return size(shape(layout)); +} + +// Return the number of modes +template +CUTE_HOST_DEVICE constexpr +auto +rank(Layout const& layout) +{ + return rank(shape(layout)); +} + +// Return the depth of the layout +template +CUTE_HOST_DEVICE constexpr +auto +depth(Layout const& layout) +{ + return depth(shape(layout)); +} + +// Return the codomain shape of a mode +// @post size(coshape(@a a)) == cosize(@a a) +// @return C Coordinate with smallest elements such that +// @a elem_less(sub_layout(c), C) for all c < size(@a sub_layout) +// where sub_layout = get(layout). +template +CUTE_HOST_DEVICE constexpr +auto +coshape(Layout const& layout) +{ + // Protect against negative strides + auto abs_sub_layout = make_layout(shape(layout), + transform_leaf(stride(layout), abs_fn{})); + auto co_coord = as_arithmetic_tuple(abs_sub_layout(size(abs_sub_layout) - Int<1>{})); + return co_coord + repeat_like(co_coord, Int<1>{}); +} + +// Return the codomain size of a mode +// @return M smallest integer such that +// @a sub_layout(c) < M for all c < size(@a sub_layout) +// where sub_layout = get(layout). 
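As a standalone illustration of the Layout accessors above and the coshape/cosize notion documented here (not part of the patch; the umbrella include and its path are assumptions): a layout maps logical coordinates to linear indices, size() measures its domain, and cosize() measures its codomain, so the two agree only for compact layouts.

// Illustrative sketch only -- not part of this patch; include path is assumed.
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  // 4x8 row-major layout: shape (4,8), stride (8,1)
  auto row_major = make_layout(make_shape(Int<4>{}, Int<8>{}),
                               make_stride(Int<8>{}, Int<1>{}));

  int  idx = row_major(2, 3);         // logical coord (2,3) -> 2*8 + 3*1 = 19
  auto n   = size(row_major);         // domain size: 32
  auto cs  = cosize(row_major);       // codomain size: 32 (this layout is compact)

  // A strided, non-compact layout: 4 elements with stride 2
  auto strided = make_layout(make_shape(4), make_stride(2));
  auto cs2 = cosize(strided);         // largest index + 1 = 3*2 + 1 = 7, while size() is 4

  return (idx == 19 && n == 32 && cs == 32 && cs2 == 7) ? 0 : 1;
}

The gap between size() and cosize() in the strided case is what the complement and inverse utilities later in this header reason about when they fill or invert a layout's codomain.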
+template +CUTE_HOST_DEVICE constexpr +auto +cosize(Layout const& layout) +{ + return size(coshape(layout)); +} + +template +using cosize_t = decltype(cosize(declval())); + +template +static constexpr int cosize_v = cosize_t::value; + +// With crd2idx(coord, shape), makes sense to have crd2idx(coord, Layout) as well +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& c, Layout const& layout) +{ + return crd2idx(c, layout.shape(), layout.stride()); +} + +// +// Slice and Dice a layout +// + +template +CUTE_HOST_DEVICE constexpr +auto +slice(Coord const& c, Layout const& layout) +{ + return make_layout(slice(c, layout.shape()), + slice(c, layout.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& c, Layout const& layout) +{ + return cute::make_tuple(slice(c, layout), crd2idx(c, layout)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +dice(Coord const& c, Layout const& layout) +{ + return make_layout(dice(c, layout.shape()), + dice(c, layout.stride())); +} + +// Compute a pointer offset and (potentially modified) layout from a coordinate +// This exists so it can be overloaded for ComposedLayout +template +CUTE_HOST_DEVICE constexpr +auto +domain_offset(Coord const& coord, Layout const& layout) +{ + return cute::make_tuple(layout, layout(coord)); +} + +// +// Transform the modes of a layout +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple const& t, F&& f, seq) +{ + return make_layout(f(get(t))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f, seq, seq, seq) +{ + return make_layout(f(get(t0),get(t1))..., get(t0)..., get(t1)...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple const& t, F&& f) +{ + return detail::transform_layout(t, f, make_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f) +{ + constexpr int R0 = decltype(rank(t0))::value; + constexpr int R1 = decltype(rank(t1))::value; + constexpr int R = (R0 < R1) ? 
R0 : R1; + return detail::transform_layout(t0, t1, f, make_seq{}, make_range{}, make_range{}); +} + +// +// Coalesce and Filter +// + +namespace detail { + +// Look at each element and the front of the stack (in order of priority) +// front(NewLayout) get(Layout) +// s0:d0 _1:d1 => continue +// _1:d0 s1:d1 => replace_front s1:d1 +// s0:s1*d1 s1:d1 => replace_front s0*s1:d1 +// s0:d0 s1:d1 => prepend s1:d1 +// +// @pre OldShape and OldStride are flat +template +CUTE_HOST_DEVICE constexpr +auto +bw_coalesce(OldShape const& old_shape, OldStride const& old_stride, + NewShape const& new_shape, NewStride const& new_stride) +{ + if constexpr (I == -1) { + // Base case, we're done + if constexpr (is_constant<1, NewShape>::value) { + return Layout<_1,_0>{}; + } else { + return Layout{new_shape,new_stride}; + } + } else if constexpr (is_constant<1, decltype(get(old_shape))>::value) { + // shape(layout) == _1, skip it and continue + return bw_coalesce(old_shape, old_stride, new_shape, new_stride); + } else if constexpr (is_constant<1, NewShape>::value) { + // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) + return bw_coalesce(old_shape, old_stride, get(old_shape), get(old_stride)); + } else if constexpr (is_constant(old_shape) * get(old_stride) == get<0>(new_stride))>::value) { + // Merge modes because the shapes and strides match + return bw_coalesce(old_shape, old_stride, + replace_front(new_shape, get(old_shape) * get<0>(new_shape)), + replace_front(new_stride, get(old_stride))); + } else { + // Can't replace or merge, so prepend a new mode + return bw_coalesce(old_shape, old_stride, + prepend(new_shape, get(old_shape)), + prepend(new_stride, get(old_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// "Simplify" the layout by combining modes that are possible to combine +// Does not respect the shape of the layout, but does preserve total size +// @post size(@a result) == size(@a layout) +// @post depth(@a result) <= 1 +// @post for all i, 0 <= i < size(@a layout), @a layout(i) == @a result(i) +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + + constexpr int R = decltype(rank(flat_shape))::value; + return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); +} + +// Apply coalesce at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce(l,t); }); + } else { + return coalesce(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// Replace the modes in layout that have a 0-stride with a 1-size +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Layout const& layout) +{ + return make_layout(filter_zeros(layout.stride(), layout.shape()), layout.stride()); +} + +// Remove all of the 0-strides and 1-sizes +// Return 1-shape if empty +template +CUTE_HOST_DEVICE constexpr +auto +filter(Layout const& layout) +{ + return coalesce(filter_zeros(layout)); +} + +// Apply filter at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +filter(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return transform_layout(layout, 
trg_profile, [](auto const& l, auto const& t) { return filter(l,t); }); + } else { + return filter(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Append, Prepend, Replace +// + +template +CUTE_HOST_DEVICE constexpr +auto +append(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(append(layout.shape(), x.shape()), + append(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(append(layout.shape(), x.shape()), + append(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(prepend(layout.shape(), x.shape()), + prepend(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(prepend(layout.shape(), x.shape()), + prepend(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +replace(Layout const& layout, + Layout const& x) +{ + return make_layout(replace(layout.shape(), x.shape()), + replace(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group(Layout const& layout) +{ + return make_layout(group(layout.shape()), + group(layout.stride())); +} + +// +// Composition of two layouts: lhs o rhs +// @post compatible(rhs, result) +// @post result(c) = lhs(rhs(c)) +// for all c in the domain of rhs +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +composition_impl(Layout const& lhs, + RShape const& rhs_shape, RStride const& rhs_stride) +{ + if constexpr (is_tuple::value) { + // Apply the right-distributivity of Layout composition + return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { return composition_impl(lhs, s, d); }); + } else + if constexpr (is_scaled_basis::value) { + // Special case for a ScaledBasis stride + return composition_impl(get(lhs), rhs_shape, rhs_stride.value()); + } else + if constexpr (is_integral::value) { + // Integral Rstride (and RShape) + + // NOTE: Should only flatten once for efficiency + auto flat_shape = flatten(lhs.shape()); + [[maybe_unused]] auto flat_stride = flatten(lhs.stride()); + [[maybe_unused]] constexpr int R = rank(flat_shape); + + if constexpr (is_constant<0, RStride>::value) { + // Special case shortcut for any static stride-0 + return Layout{rhs_shape, rhs_stride}; + } else + if constexpr (is_integral::value) { + // Special case shortcut for any integral LShape + auto result_stride = rhs_stride * flat_stride; + return Layout{rhs_shape, result_stride}; + } else + if constexpr (is_constant<1, RStride>::value) { + // Special case shortcut for any static stride-1 + auto result_shape_0 = take<0,R-1>(flat_shape); + + // Mod out the rhs_shape from the lhs.shape() + auto const [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape), + [] (auto const& init, auto const& si) { + return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + }); + + // Jump into coalesce and append (rest_shape, get(lhs.stride()) + return detail::bw_coalesce(result_shape_1, flat_stride, rest_shape, get(flat_stride)); + } else + { + // General case + auto result_shape_0 = take<0,R-1>(flat_shape); + auto result_stride_0 = take<0,R-1>(flat_stride); + + // Divide out the rhs_stride from the lhs.shape() + auto const [result_shape_1, rest_stride] = 
fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride), + [] (auto const& init, auto const& di) { + return cute::make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di)); + }); + + // Apply any lhs.shape() changes to the stride + auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1)); + + // Mod out the rhs_shape from the lhs.shape() + auto const [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape), + [] (auto const& init, auto const& si) { + return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + }); + + // Jump into coalesce and append (rest_shape, rest_stride * get(lhs.stride()) + return detail::bw_coalesce(result_shape_2, result_stride_1, rest_shape, rest_stride * get(flat_stride)); + } + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& lhs, + Layout const& rhs) +{ + return detail::composition_impl(lhs, rhs.shape(), rhs.stride()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& lhs, + Tiler const& rhs) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + // Drop any modes of lhs that aren't hit by rhs + return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq::value>{}, seq<>{}, seq<>{}); + } else if constexpr (is_underscore::value) { + return lhs; + } else if constexpr (is_integral::value) { + return detail::composition_impl(lhs, rhs, Int<1>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Complement +// +// Build the complement of a layout. +// @post size(@a result) >= @a cosize_hi / size(filter(@a layout))); +// @post For all i in [1,size(@a result)), +// @a result(i) < @a result(i-1) +// For all j in [0, size(@a layout)), +// @a result(i) != @a layout(j) +// + +namespace detail { + +// @pre @a layout has been filtered (flattened and no stride-0 or size-1 modes). +template +CUTE_HOST_DEVICE constexpr +auto +complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) +{ + if constexpr (is_constant<0, Stride>::value) { + // Special case for irreducible rank-1 stride-0 layout + return make_layout(cosize_hi); + } else { + // General case + constexpr int R = rank_v; + static_assert(R == 1 || is_static::value, + "Dynamic-stride complement only for rank-1 layouts"); + + // Should just be a sort and a fold... 
+ // Then we could even handle dynamic strides (but they would destroy all static strides) + auto [shape_, stride_, result_shape_, result_stride] = + fold(make_seq{}, + cute::make_tuple(shape, stride, cute::make_tuple(), cute::make_tuple(Int<1>{})), + [](auto const& init, auto i) + { + auto [shape, stride, result_shape, result_stride] = init; + auto min_stride = cute::min(stride); + auto min_idx = find(stride, min_stride); + auto new_shape = min_stride / get(result_stride); + auto new_stride = get(shape) * min_stride; + static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); + + return cute::make_tuple(remove(shape), // Remove the min_idx from shape + remove(stride), // Remove the min_idx from stride + append(result_shape , new_shape ), // new shape = min_stride / last_stride + append(result_stride, new_stride)); // new stride = curr_shape * min_stride + }); + + // Append the last shape mode + auto new_shape = get<0>(stride_) / get(result_stride); + static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); + auto result_shape = append(result_shape_, new_shape); // new shape = min_stride / last_stride + + // Compute the rest_shape and rest_stride + auto rest_stride = get<0>(shape_) * get<0>(stride_); + auto rest_shape = ceil_div(cosize_hi, rest_stride); + + // Jump into coalesce and append (rest_shape, rest_stride) + return detail::bw_coalesce(result_shape, result_stride, rest_shape, rest_stride); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +complement(Layout const& layout, CoSizeHi const& cosize_hi) +{ + static_assert(cute::is_integral::value, "Expected integral codomain size in complement."); + auto filter_layout = filter(layout); + return detail::complement(filter_layout.shape(), filter_layout.stride(), cosize_hi); +} + +template +CUTE_HOST_DEVICE constexpr +auto +complement(Layout const& layout) +{ + auto filter_layout = filter(layout); + return detail::complement(filter_layout.shape(), filter_layout.stride(), cosize(filter_layout)); +} + +// +// Right-Inverse and Left-Inverse +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +inverse_seq(Shape const& shape, Stride const& stride, seq) +{ + auto next_I = cute::find_if(stride, [](auto a) { return is_constant{}; }); + + if constexpr (next_I == decltype(rank(stride))::value) { + // If not found, return current seq + return seq{}; + } else { + // auto next_stride = get(shape) * get(stride); + // NOTE: Needed for g++-7 + using next_stride = decltype(get(shape) * get(stride)); + + if constexpr (is_static::value && !is_constant::value) { + // If next_stride is static and unique, then continue + return inverse_seq(shape, stride, seq{}); + } else { + // Else return current seq + next_I + return seq{}; + } + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// +// Build the right-inverse of a layout +// @pre is_static +// @result A layout @a result such that +// @a layout(@a result(i)) == i for all i < size(@a result) +// @result A layout @a result such that +// composition(@a layout, @a result) is identical to make_layout(shape(result)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(Layout const& layout) +{ + auto flat_layout = coalesce(layout); + auto astride = transform_leaf(flat_layout.stride(), abs_fn{}); + + // Find Int<1>{}, the starting stride, and follow the strides to gen inverse_seq + [[maybe_unused]] auto iseq = 
detail::inverse_seq<1>(flat_layout.shape(), astride, seq<>{}); + + if constexpr (iseq.size() == 0) { + return Layout<_1,_0>{}; // Empty case, nothing found + } else { + // Generate the corresponding new strides and construct + auto rstride = compact_col_major(flat_layout.shape()); + return make_layout(unwrap(transform(iseq, [&](auto i) { return shape(flat_layout); })), + unwrap(transform(iseq, [&](auto i) { return signum(stride(flat_layout)) * get(rstride); }))); + } + + CUTE_GCC_UNREACHABLE; +} + +CUTE_HOST_DEVICE constexpr +auto +right_inverse(Underscore const& _) +{ + return _; +} + +// +// Build the left-inverse of a layout +// @pre is_static +// @pre @a layout is an injective function +// @result A layout @a result such that +// @a result(@a layout(i)) == i for all i < size(@a layout) +// @result A layout @a result such that +// composition(@a result, @a layout) is identical to make_layout(shape(layout)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(Layout const& layout) +{ + return right_inverse(make_layout(layout, complement(layout))); +} + +CUTE_HOST_DEVICE constexpr +auto +left_inverse(Underscore const& _) +{ + return _; +} + +// +// Max Common Layout +// + +/* Return a layout that points to the maximum number of contiguous elements + * that logically correspond in the layouts of @a a and @a b. + * + * @returns Layout R + * @post For all 0 <= i < size(R), a(R(i)) == i and b(R(i)) == i + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_layout(Layout const& a, + Layout const& b) +{ + Layout inv_b = right_inverse(b); + Layout common = coalesce(composition(a, inv_b)); + + // Keep only the static identity component of the common layout + if constexpr (is_static(common))>::value && + is_constant<1, decltype(stride<0>(common))>::value) { + // Truncate to the size of the contiguous vector (static stride-1 mode) + return composition(inv_b, layout<0>(common)); + } else { + return Layout<_1,_0>{}; + } +} + +/* Return Int such that N is the maximum number of contiguous elements + * that logically correspond in the layouts of @a a and @a b. 
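+ *
+ * Example (illustrative, added commentary): for two identical column-major
+ * layouts a == b == (_4,_8):(_1,_4), max_common_vector(a, b) is Int<32>{},
+ * whereas for column-major a = (_4,_8):(_1,_4) against row-major
+ * b = (_4,_8):(_8,_1) it is Int<1>{}, since no run of elements is contiguous
+ * in both layouts.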
+ * + * @returns Int with N >= 1 + * @post For all 0 <= n < N, a(b.get_1d_coord(n)) == n + * (NOTE: Problems with negative strides/coords in this post-condition) + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Layout const& a, + Layout const& b) +{ + Layout common = coalesce(composition(a, right_inverse(b))); + + // Keep only the static identity component of the common layout + if constexpr (is_static(common))>::value && + is_constant<1, decltype(stride<0>(common))>::value) { + // Truncate to the size of the contiguous vector (static stride-1 mode) + return shape<0>(common); + } else { + return Int<1>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Kernel (Nullspace) of a Layout +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +nullspace_seq(Stride const& stride, seq) +{ + if constexpr (NextI == rank_v) { + return seq{}; + } else + if constexpr (is_constant<0, decltype(get(stride))>::value) { + return detail::nullspace_seq(stride, seq{}); + } else { + return detail::nullspace_seq(stride, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// +// Build the nullspace of a layout +// @result A layout @a result such that +// size(@a result) == size(@a layout) / size(filter(@a layout)) +// @a layout(@a result(i)) == 0 for all i < size(@a result) +// + +template +CUTE_HOST_DEVICE constexpr +auto +nullspace(Layout const& layout) +{ + auto flat_layout = flatten(layout); + + auto iseq = detail::nullspace_seq<0>(flat_layout.stride(), seq<>{}); + + if constexpr (iseq.size() == 0) { + return Layout<_1,_0>{}; // Empty case, nothing found + } else { + // Generate the corresponding new strides and construct + auto rstride = compact_col_major(flat_layout.shape()); + return make_layout(unwrap(transform(iseq, [&](auto i) { return shape(flat_layout); })), + unwrap(transform(iseq, [&](auto i) { return get(rstride); }))); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Zip +// + +template +CUTE_HOST_DEVICE constexpr +auto +zip(Layout const& layout) +{ + return make_layout(zip(layout.shape()), + zip(layout.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zip(Layout const& layoutA, + Layout const& layoutB) +{ + return make_layout(zip(layoutA.shape(), layoutB.shape()), + zip(layoutA.stride(), layoutB.stride())); +} + +// +// Tile unzip +// Logical product and logical divide (on layouts) produce rank-2 results by design. +// Follow the profile of @a tile and zip the rank-2 modes located at the terminals into +// their own mode. 
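+//
+// Example (illustrative, added commentary): dividing (_8,_24):(_1,_8) by the
+// tiler (_4,_8) gives per-mode rank-2 shapes ((_4,_2),(_8,_3)); tile_unzip
+// regroups them as ((_4,_8),(_2,_3)) -- the tile modes first, the "rest" modes second.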
+// + +template +CUTE_HOST_DEVICE constexpr +auto +tile_unzip(Layout const& layout, + Tiler const& tiler) +{ + return make_layout(zip2_by(layout.shape(), tiler), + zip2_by(layout.stride(), tiler)); +} + +// +// Logical divide +// + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Layout const& layout, + Layout const& tiler) +{ + return composition(layout, make_layout(tiler, complement(tiler, size(layout)))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Layout const& layout, + Tiler const& tiler) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank, "logical_divide: Too many modes in tiler."); + return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_divide(l,t); }); + } else if constexpr (is_underscore::value) { + return layout; + } else if constexpr (is_integral::value) { + return logical_divide(layout, make_layout(tiler)); + } + + CUTE_GCC_UNREACHABLE; +} + +// Generalization of ceil_div for Layout lhs +// is effectively the "rest mode" of logical_divide. +// Occurs in the calculation of gridDim, for example, for generalized tilers +// Example: +// dim3 gridDim(size(ceil_div(problem_shape_M, cta_tiler_M)), +// size(ceil_div(problem_shape_N, cta_tiler_N))); +// This does not consider compositional acceptance, so it may be the case that +// ceil_div produces a result while logical_divide (and friends) do not. +template +CUTE_HOST_DEVICE constexpr +auto +ceil_div(Target const& target, + Layout const& tiler) +{ + return complement(tiler, size(target)); +} + +// +// Convenience operator +// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// by gathering the tile modes and residuals into a rank-2 result. +// + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(Layout const& layout, + Tiler const& tiler) +{ + return tile_unzip(logical_divide(layout, tiler), tiler); +} + +// Same as zipped_divide, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(Layout const& layout, + Tiler const& tiler) +{ + auto result = zipped_divide(layout, tiler); + + auto R1 = rank<1>(result); + return result(_, repeat(_)); +} + +// Same as zipped_divide, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +flat_divide(Layout const& layout, + Tiler const& tiler) +{ + auto result = zipped_divide(layout, tiler); + + auto R0 = rank<0>(result); + auto R1 = rank<1>(result); + return result(repeat(_), repeat(_)); +} + +// +// Logical product +// + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& block, + Layout const& tiler) +{ + return make_layout(block, composition(complement(block, size(block)*cosize(tiler)), tiler)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& block, + Tiler const& tiler) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank, "logical_product: Too many modes in tiler."); + return transform_layout(block, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); }); + } else if constexpr (is_underscore::value) { + return block; + } else if constexpr (is_integral::value) { + return logical_product(block, make_layout(tiler)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Convenience operator +// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// by gathering the block modes and products into a rank-2 result. 
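+//
+// Example (illustrative, added commentary): logical_product(2:1, 3:1) yields
+// (2,3):(1,2) -- three end-to-end copies of the block -- and the zipped/tiled/flat
+// variants below differ only in how the block modes and repetition modes are regrouped.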
+// + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_product(Layout const& block, + Tiler const& tiler) +{ + return tile_unzip(logical_product(block, tiler), tiler); +} + +// Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +tiled_product(Layout const& block, + Tiler const& tiler) +{ + auto result = zipped_product(block, tiler); + + auto R1 = rank<1>(result); + return result(_, repeat(_)); +} + +// Same as zipped_product, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +flat_product(Layout const& block, + Tiler const& tiler) +{ + auto result = zipped_product(block, tiler); + + auto R0 = rank<0>(result); + auto R1 = rank<1>(result); + return result(repeat(_), repeat(_)); +} + +// +// Rank-sensitive products +// + +// blocked_product -- Reproduce a block over a tiler. +// Think of every element of "tiler" as a "block" +// and return the layout of the resulting structure. +// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler)) +template +CUTE_HOST_DEVICE constexpr +auto +blocked_product(Layout const& block, + Layout const& tiler) +{ + constexpr int R = cute::max(rank_v, rank_v); + + auto result = logical_product(append(block), append(tiler)); + + return coalesce(zip(get<0>(result), get<1>(result)), tuple_repeat(Int<1>{})); +} + +// raked_product -- Reproduce a block over a tiler with block-interleaving. +// Think of every element of "tiler" as a "block", interleave those blocks, +// and return the layout of the resulting structure. +// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler)) +template +CUTE_HOST_DEVICE constexpr +auto +raked_product(Layout const& block, + Layout const& tiler) +{ + constexpr int R = cute::max(rank_v, rank_v); + + auto result = logical_product(append(block), append(tiler)); + + return coalesce(zip(get<1>(result), get<0>(result)), tuple_repeat(Int<1>{})); +} + +// tile_to_shape -- Perform a product of a layout so that the result matches a target shape. +// This is similar to blocked_product, but specifies the result shape instead of the +// product shape, which is more convenient in certain circumstances. +// @param block The layout to repeat +// @param trg_shape The target shape of the result +// @param ord_shape The order of the modes of @a trg_shape to tile @a layout with. +// Defaults to GenColMajor, so @a layout will repeat +// across the first mode first, the second mode second, etc +// E.g. Step<_2,_1,_3> will cause @a layout to repeat +// across the second mode first, the first mode second, and the third mode last. 
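+// Example (illustrative, added commentary):
+//   tile_to_shape((_2,_2):(_1,_2), make_shape(_8,_8))
+// repeats the 2x2 block four times along each mode, producing a 64-element
+// layout whose shape is compatible with (_8,_8).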
+// @pre rank(@a block) <= rank(@a trg_shape) +// @post compatible(@a trg_shape, shape(@a result)) +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_shape(Layout const& block, + TrgShape const& trg_shape, + ModeOrder const& ord_shape = {}) +{ + CUTE_STATIC_ASSERT_V(rank(block) <= rank(trg_shape), "Rank of layout must be <= rank of target shape."); + constexpr int R = rank_v; + + auto padded_block = append(block); + + auto block_shape = product_each(shape(padded_block)); + auto target_shape = product_each(shape(trg_shape)); + + // Assert proper division + if constexpr (is_static::value) { + CUTE_STATIC_ASSERT_V(weakly_compatible(block_shape, target_shape), + "tile_to_shape: block shape does not divide the target shape."); + } + + auto product_shape = ceil_div(target_shape, block_shape); + + return coalesce(blocked_product(padded_block, make_ordered_layout(product_shape, ord_shape)), product_shape); +} + +// +// Upcast +// For stride-1 mode, divide size by N. Divide all other strides by N. +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { // tuple stride + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_constant<0, Stride>::value) { // static-0 stride + return Layout{shape,stride}; + } else if constexpr (is_static::value) { // static stride + return make_layout(shape_div(shape, shape_div(Int{}, abs(stride))), + shape_div(stride, Int{})); + } else { // dynamic stride + // assume dynamic strides are larger than N and divisible + // assert(stride % N == 0); + return make_layout(shape, safe_div(stride, Int{})); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Layout const& layout) +{ + return upcast(layout.shape(), layout.stride()); +} + +// +// Downcast +// For stride-1 mode, multiply size by N. Multiply all other strides by N. 
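+//
+// Example (illustrative, added commentary):
+//   upcast<2>((_4,_8):(_1,_4))   == (_2,_8):(_1,_2)  -- pairs of elements viewed as one wider element
+//   downcast<2>((_4,_8):(_1,_4)) == (_8,_8):(_1,_8)  -- each element split in two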
+// + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return downcast(s,d); }); + } else if constexpr (is_constant<1, Stride>::value || is_constant<-1, Stride>::value) { + return make_layout(shape * Int{}, stride); + } else { + return make_layout(shape, stride * Int{}); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Layout const& layout) +{ + CUTE_STATIC_ASSERT(has_int1::value, "Downcast requires adjacent elements"); + return downcast(layout.shape(), layout.stride()); +} + +// +// Recast +// + +template +CUTE_HOST_DEVICE constexpr +auto +recast_layout(Layout const& layout) +{ + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { + return layout; + } + else if constexpr (scale::num == 1) { + return downcast(layout); + } + else if constexpr (scale::den == 1) { + return upcast(layout); + } + else { + static_assert(dependent_false, "Recast not supported."); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(Layout const& layout) +{ + print(layout.shape()); print(":"); print(layout.stride()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout const& layout) +{ + return os << shape(layout) << ":" << stride(layout); +} +#endif + +// Generic 2D Layout to console table +template +CUTE_HOST_DEVICE +void +print_layout(Layout const& layout) // (m,n) -> idx +{ + CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + + int idx_width = num_digits(cosize(layout)) + 2; + const char* delim = "+-----------------------"; + + print(layout); print("\n"); + + // Column indices + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); } + printf("\n"); + + // Print out A m-by-n + for (int m = 0; m < size<0>(layout); ++m) { + // Header + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } + printf("+\n"); + // Values + printf("%2d ", m); // Row indices + for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); } + printf("|\n"); + } + // Footer + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } + printf("+\n"); +} + +// Generic ThrVal 2D Layout to console table +template +CUTE_HOST_DEVICE +void +print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + + print(layout); print("\n"); + print(thrid); print("\n"); + + // Print out m-by-n + for (int m = 0; m < size<0>(layout); ++m) { + // Header + for (int n = 0; n < size<1>(layout); ++n) printf("+------"); + printf("+\n"); + // Values + for (int n = 0; n < size<1>(layout); ++n) printf("|%03d-%02d", int(thrid(layout(m,n) % size(thrid))), int(layout(m,n) / size(thrid))); + printf("|\n"); + } + // Footer + for (int n = 0; n < size<1>(layout); ++n) printf("+------"); + printf("+\n"); +} + +// Generic 2D Layout to Latex printer -- B&W 8-value color coding +template +CUTE_HOST_DEVICE +void +print_latex(LayoutA const& layout_a) +{ + CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{}); + auto layout = append<2>(layout_a, Layout<_1,_0>{}); + + char const* latex_header = + "\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + 
"\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center,font=\\Large}]\n\n"; + char const* latex_footer = + "\\end{tikzpicture}\n" + "\\end{document}\n"; + + char const* color_map[8] = {"black!00", + "black!40", + "black!20", + "black!60", + "black!10", + "black!50", + "black!30", + "black!70"}; + + // Header + printf("%% Layout: "); print(layout); printf("\n"); + + printf(latex_header); + + // Layout + for (int i = 0; i < size<0>(layout); ++i) { + for (int j = 0; j < size<1>(layout); ++j) { + int idx = layout(i,j); + printf("\\node[box,fill=%s] at (%d,%d) {%d};\n", + color_map[idx % 8], + i, j, + idx); + } + } + + // Labels + for (int i = 0, j = -1; i < size<0>(layout); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); + } + for (int j = 0, i = -1; j < size<1>(layout); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); + } + + // Footer + printf(latex_footer); +} + +// Generic ThrVal 2D Layout to Latex TIKZ -- 8-value color coded by thread +template +CUTE_HOST_DEVICE +void +print_latex(Layout const& layout, ThrID const& thr) // (m,n) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + + char const* latex_header = + "\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; + char const* latex_footer = + "\\end{tikzpicture}\n" + "\\end{document}\n"; + + char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}"}; + + // Header + printf("%% layout: "); print(layout); printf("\n"); + printf("%% thrid: "); print(thr); printf("\n\n"); + + printf(latex_header); + + // Layout + for (int i = 0; i < size<0>(layout); ++i) { + for (int j = 0; j < size<1>(layout); ++j) { + int thrid = layout(i,j) % size(thr); + int val_idx = layout(i,j) / size(thr); + int thr_idx = thr(thrid); + + printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color_map[thr_idx % 8], + i, j, + thr_idx, val_idx); + } + } + + // Labels + for (int i = 0, j = -1; i < size<0>(layout); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); + } + for (int j = 0, i = -1; j < size<1>(layout); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); + } + + // Footer + printf(latex_footer); +} + +} // end namespace cute + +// +// Extended Layouts +// + +#include diff --git a/server/punica_kernels/include/cutlass/cute/layout_composed.hpp b/server/punica_kernels/include/cutlass/cute/layout_composed.hpp new file mode 100644 index 00000000..21845a4e --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/layout_composed.hpp @@ -0,0 +1,640 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +/* This implements a ComposedLayout of the form + * LayoutA o Offset o LayoutB + * and is useful in cases where composition() does not or cannot apply to LayoutA and LayoutB. + * For example, when the "divisibility condition" in shape_div is violated in composition(LayoutA, LayoutB). + * + * This ComposedLayout provides similar functionality to Layout including tiling, partitioning, + * coordinate-to-index mapping and layout manipulations, but is not considered a "normal" layout. + * For example, this layout provides shape() and size() functions, but does not provide stride() functions. + * Mostly, the similar functionality is accomplished by applying each operation to LayoutB only + * as LayoutB defines the domain. 
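+ *
+ * Illustrative sketch (added commentary, not part of the original source):
+ * evaluating a ComposedLayout at a coordinate c computes
+ *   (LayoutA o Offset o LayoutB)(c) = LayoutA(Offset + LayoutB(c)),
+ * e.g. with a swizzle function as LayoutA this describes swizzled shared-memory
+ * layouts while the shape/domain is still given entirely by LayoutB.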
+ */ + +namespace cute +{ + +// A Layout of non-trivially composable functions: F o I o L +template +struct ComposedLayout : private cute::tuple // EBO for static layouts +{ + CUTE_HOST_DEVICE constexpr + ComposedLayout(LayoutA const& layoutA = {}, + Offset const& offset = {}, + LayoutB const& layoutB = {}) + : cute::tuple(layoutA, offset, layoutB) + {} + + // + // Accessors + // + + static constexpr int rank = LayoutB::rank; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout_a() const { + return get<0>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + offset() const { + return get<1>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout_b() const { + return get<2>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return layout_b().shape(); + } + + // Doesn't really make sense to ask for the strides of this "layout" + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const = delete; + + // + // Mappings + // + + // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + // OR + // Slice the layout and return the sublayout (Coord has an Underscore slice op) + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return slice(coord, *this); + } else { + return layout_a()(offset() + layout_b()(coord)); // (A o O o B)(c) + } + + CUTE_GCC_UNREACHABLE; + } + + // Convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(OtherLayout const& other) const { + return composition(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return composition(*this, make_tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(OtherShape const& shape) const { + return composition(*this, make_layout(shape)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(Shapes const&... shapes) const { + return composition(*this, make_layout(make_shape(shapes...))); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(OtherLayout const& other) const { + return tiled_divide(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... 
layouts) const { + return tiled_divide(*this, make_tile(layouts...)); + } + + // Equality, return a static or dynamic boolean + template + CUTE_HOST_DEVICE constexpr + auto + operator==(ComposedLayout const& other) const { + return this->layout_a() == other.layout_a() && + this->layout_b() == other.layout_b() && + this->offset() == other.offset(); + } +}; + +template +struct is_layout> : true_type {}; + +template +struct is_composed_layout : false_type {}; +template +struct is_composed_layout> : true_type {}; + +// +// Constructors +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_composed_layout(LayoutA const& layoutA, + Offset const& offset, + LayoutB const& layoutB) +{ + return ComposedLayout{layoutA, offset, layoutB}; +} + +// +// Utilities +// + +// Return the layout of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(ComposedLayout const& clayout) +{ + return composition(clayout.layout_a(), clayout.offset(), layout(clayout.layout_b())); +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(ComposedLayout const& layout) +{ + return shape(layout.layout_b()); +} + +// Doesn't make sense to directly ask for the strides of this "layout" +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(ComposedLayout const& layout) = delete; + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +size(ComposedLayout const& layout) +{ + return size(layout.layout_b()); +} + +// Return the number of modes +template +CUTE_HOST_DEVICE constexpr +auto +rank(ComposedLayout const& layout) +{ + return rank(layout.layout_b()); +} + +// Return the depth of the layout +template +CUTE_HOST_DEVICE constexpr +auto +depth(ComposedLayout const& layout) +{ + return depth(layout.layout_b()); +} + +// Return the codomain size of a mode +template +CUTE_HOST_DEVICE constexpr +auto +cosize(ComposedLayout const& layout) +{ + return cosize(layout.layout_b()); +} + +// +// Operations to manipulate Layouts like a tuple of pairs +// + +template +CUTE_HOST_DEVICE constexpr +auto +get(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), get(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +take(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), take(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), flatten(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(ComposedLayout const& a, X const& x) +{ + return composition(a.layout_a(), a.offset(), append(a.layout_b(), x)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), group(a.layout_b())); +} + +// +// Slice a ComposedLayout +// + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& coord, ComposedLayout const& layout) +{ + auto [slice, offset] = slice_and_offset(coord, layout.layout_b()); + return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + offset, slice}, Int<0>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +slice(Coord const& coord, ComposedLayout const& layout) +{ + return get<0>(slice_and_offset(coord, layout)); +} + +// Compute a pointer offset and (potentially modified) layout from a coordinate +// For composed layout tensors the offset is accumulated in the layout itself while pointer is not updated +template +CUTE_HOST_DEVICE constexpr +auto 
+domain_offset(Coord const& coord, ComposedLayout const& layout) +{ + return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + layout.layout_b()(coord), layout.layout_b()}, Int<0>{}); +} + +// +// composition +// + +template +CUTE_HOST_DEVICE constexpr +auto +composition(LayoutA const& layoutA, + Offset const& offset, + LayoutB const& layoutB) +{ + return ComposedLayout{layoutA, offset, layoutB}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), composition(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& a, + ComposedLayout const& b) +{ + CUTE_STATIC_ASSERT_V(b.offset() == Int<0>{}, "Require offset == 0."); + + return composition(composition(a, b.layout_a()), b.layout_b()); +} + +// +// complement +// + +template +CUTE_HOST_DEVICE constexpr +auto +complement(ComposedLayout const& layout, CoSizeHi const& cosize_hi) +{ + return complement(layout.layout_b(), cosize_hi); +} + +template +CUTE_HOST_DEVICE constexpr +auto +complement(ComposedLayout const& layout) +{ + return complement(layout, cosize(layout)); +} + +// +// inverse +// + +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(ComposedLayout const& layout) +{ + return composition(right_inverse(layout.layout_b()), right_inverse(layout.offset()), right_inverse(layout.layout_a())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(ComposedLayout const& layout) +{ + return composition(left_inverse(layout.layout_b()), left_inverse(layout.offset()), left_inverse(layout.layout_a())); +} + +// +// Other operations +// + +template +CUTE_HOST_DEVICE constexpr +auto +zip(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), zip(a.layout_b())); +} + +// Partitions + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), logical_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_unzip(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), tile_unzip(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), tiled_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flat_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), flat_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), logical_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), zipped_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), tiled_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flat_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), flat_product(a.layout_b(), b)); +} + 
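+// Note (added commentary): each product/divide overload in this file follows the
+// same forwarding pattern -- apply the operation to layout_b(), which defines the
+// domain, and re-wrap the result with the unchanged layout_a() and offset():
+//   composition(a.layout_a(), a.offset(), op(a.layout_b(), b))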
+template +CUTE_HOST_DEVICE constexpr +auto +blocked_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), blocked_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +raked_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), raked_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_shape(ComposedLayout const& layout, + Shape const& trg_shape, + ModeOrder const& ord_shape = {}) +{ + return composition(layout.layout_a(), layout.offset(), tile_to_shape(layout.layout_b(), trg_shape, ord_shape)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter(ComposedLayout const& layout, Shape const& trg_profile) +{ + return composition(layout.layout_a(), layout.offset(), filter(layout.layout_b(), trg_profile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(ComposedLayout const& layout) +{ + return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(ComposedLayout const& layout, Shape const& trg_profile) +{ + return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b(), trg_profile)); +} + +// +// Upcast and Downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout const& layout) +{ + return composition(upcast(layout.layout_a()), upcast(layout.offset()), upcast(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout const& layout) +{ + return composition(downcast(layout.layout_a()), downcast(layout.offset()), downcast(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast_layout(ComposedLayout const& layout) +{ + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { + return layout; + } + else if constexpr (scale::num == 1) { + return downcast(layout); + } + else if constexpr (scale::den == 1) { + return upcast(layout); + } + else { + static_assert(dependent_false, "Recast not supported."); + } + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(ComposedLayout const& layout) +{ + print(layout.layout_a()); print(" o "); print(layout.offset()); print(" o "); print(layout.layout_b()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout const& layout) +{ + return os << layout.layout_a() << " o " << layout.offset() << " o " << layout.layout_b(); +} +#endif + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/numeric/arithmetic_tuple.hpp b/server/punica_kernels/include/cutlass/cute/numeric/arithmetic_tuple.hpp new file mode 100644 index 00000000..27d1cf8e --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/numeric/arithmetic_tuple.hpp @@ -0,0 +1,607 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace cute +{ + +template +struct ArithmeticTuple : tuple +{ + template + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(ArithmeticTuple const& u) + : tuple(static_cast const&>(u)) {} + + template + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(tuple const& u) + : tuple(u) {} + + template + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(U const&... u) + : tuple(u...) {} +}; + +template +struct is_tuple> : true_type {}; + +template +struct is_flat> : is_flat> {}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_arithmetic_tuple(T const&... t) { + return ArithmeticTuple(t...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(tuple const& t) { + return ArithmeticTuple(t); +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +T const& +as_arithmetic_tuple(T const& t) { + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ArithmeticTuple const& t) { + return t; +} + +// +// Numeric operators +// + +// Addition +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); + return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, tuple const& u) { + return t + ArithmeticTuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(tuple const& t, ArithmeticTuple const& u) { + return ArithmeticTuple(t) + u; +} + +// Subtraction +template +CUTE_HOST_DEVICE constexpr +auto +operator-(ArithmeticTuple const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); + return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), minus{}, [](auto const&... 
a){ return make_arithmetic_tuple(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator-(ArithmeticTuple const& t, tuple const& u) { + return t - ArithmeticTuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator-(tuple const& t, ArithmeticTuple const& u) { + return ArithmeticTuple(t) - u; +} + +// Negation +template +CUTE_HOST_DEVICE constexpr +auto +operator-(ArithmeticTuple const& t) { + return transform_apply(t, negate{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +// +// Special cases +// + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple const& +operator+(C, ArithmeticTuple const& u) { + static_assert(t == 0, "Arithmetic tuple op+ error!"); + return u; +} + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple const& +operator+(ArithmeticTuple const& t, C) { + static_assert(u == 0, "Arithmetic tuple op+ error!"); + return t; +} + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple const& +operator-(C, ArithmeticTuple const& u) { + static_assert(t == 0, "Arithmetic tuple op- error!"); + return -u; +} + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple const& +operator-(ArithmeticTuple const& t, C) { + static_assert(u == 0, "Arithmetic tuple op- error!"); + return t; +} + +// +// ArithmeticTupleIterator +// + +template +struct ArithmeticTupleIterator +{ + using value_type = ArithTuple; + using element_type = ArithTuple; + using reference = ArithTuple; + + ArithTuple coord_; + + CUTE_HOST_DEVICE constexpr + ArithmeticTupleIterator(ArithTuple const& coord = {}) : coord_(coord) {} + + CUTE_HOST_DEVICE constexpr + ArithTuple const& operator*() const { return coord_; } + + template + CUTE_HOST_DEVICE constexpr + auto operator[](Coord const& c) const { return *(*this + c); } + + template + CUTE_HOST_DEVICE constexpr + auto operator+(Coord const& c) const { + return ArithmeticTupleIterator(coord_ + c); + } +}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_inttuple_iter(Tuple const& t) { + return ArithmeticTupleIterator(as_arithmetic_tuple(t)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_inttuple_iter(T0 const& t0, T1 const& t1, Ts const&... ts) { + return make_inttuple_iter(cute::make_tuple(t0, t1, ts...)); +} + +// +// ArithmeticTuple "basis" elements +// A ScaledBasis is a (at least) rank-N+1 ArithmeticTuple: +// (_0,_0,...,T,_0,...) +// with value T in the Nth mode + +template +struct ScaledBasis : private tuple +{ + CUTE_HOST_DEVICE constexpr + ScaledBasis(T const& t = {}) : tuple(t) {} + + CUTE_HOST_DEVICE constexpr + decltype(auto) value() { return get<0>(static_cast &>(*this)); } + CUTE_HOST_DEVICE constexpr + decltype(auto) value() const { return get<0>(static_cast const&>(*this)); } + + CUTE_HOST_DEVICE static constexpr + auto mode() { return Int{}; } +}; + +template +struct is_scaled_basis : false_type {}; +template +struct is_scaled_basis> : true_type {}; + +template +struct is_integral> : true_type {}; + +// Get the scalar T out of a ScaledBasis +template +CUTE_HOST_DEVICE constexpr auto +basis_value(SB const& e) +{ + if constexpr (is_scaled_basis::value) { + return basis_value(e.value()); + } else { + return e; + } + CUTE_GCC_UNREACHABLE; +} + +// Apply the N... 
pack to another Tuple +template +CUTE_HOST_DEVICE constexpr auto +basis_get(SB const& e, Tuple const& t) +{ + if constexpr (is_scaled_basis::value) { + return basis_get(e.value(), get(t)); + } else { + return t; + } + CUTE_GCC_UNREACHABLE; +} + +namespace detail { + +template +struct Basis; + +template <> +struct Basis<> { + using type = Int<1>; +}; + +template +struct Basis { + using type = ScaledBasis::type, N>; +}; + +} // end namespace detail + +// Shortcut for writing ScaledBasis, N0>, N1>, ...> +// E<> := _1 +// E<0> := (_1,_0,_0,...) +// E<1> := (_0,_1,_0,...) +// E<0,0> := ((_1,_0,_0,...),_0,_0,...) +// E<0,1> := ((_0,_1,_0,...),_0,_0,...) +// E<1,0> := (_0,(_1,_0,_0,...),_0,...) +// E<1,1> := (_0,(_0,_1,_0,...),_0,...) +template +using E = typename detail::Basis::type; + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(T const& t, seq, seq) { + return make_arithmetic_tuple((void(I),Int<0>{})..., t, (void(J),Int<0>{})...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ArithmeticTuple const& t, seq, seq) { + return make_arithmetic_tuple(get(t)..., (void(J),Int<0>{})...); +} + +} // end namespace detail + +// Turn a ScaledBases into a rank-M ArithmeticTuple +// with N prefix 0s: (_0,_0,...N...,_0,T,_0,...,_0,_0) +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + static_assert(M > N, "Mismatched ranks"); + return detail::as_arithmetic_tuple(t.value(), make_seq{}, make_seq{}); +} + +// Turn a ScaledBases into a rank-N ArithmeticTuple +// with N prefix 0s: (_0,_0,...N...,_0,T) +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + return as_arithmetic_tuple(t); +} + +// Turn an ArithmeticTuple into a rank-M ArithmeticTuple +// with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0) +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ArithmeticTuple const& t) { + static_assert(M >= sizeof...(T), "Mismatched ranks"); + return detail::as_arithmetic_tuple(t, make_seq{}, make_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +safe_div(ScaledBasis const& b, U const& u) +{ + auto t = safe_div(b.value(), u); + return ScaledBasis{t}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +shape_div(ScaledBasis const& b, U const& u) +{ + auto t = shape_div(b.value(), u); + return ScaledBasis{t}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_basis_like(Shape const& shape) +{ + if constexpr (is_integral::value) { + return Int<1>{}; + } + else { + // Generate bases for each rank of shape + return transform(tuple_seq{}, shape, [](auto I, auto si) { + // Generate bases for each rank of si and add an i on front + using I_type = decltype(I); + return transform_leaf(make_basis_like(si), [](auto e) { + // MSVC has trouble capturing variables as constexpr, + // so that they can be used as template arguments. + // This is exactly what the code needs to do with i, unfortunately. + // The work-around is to define i inside the inner lambda, + // by using just the type from the enclosing scope. 
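+        // (Added note) Each leaf of @a shape becomes a ScaledBasis tagged with its
+        // mode index; e.g. make_basis_like(Shape<_2,_2>{}) produces (E<0>{}, E<1>{}).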
+ constexpr int i = I_type::value; + return ScaledBasis{}; + }); + }); + } + + CUTE_GCC_UNREACHABLE; +} + +// Equality +template +CUTE_HOST_DEVICE constexpr +auto +operator==(ScaledBasis const& t, ScaledBasis const& u) { + return bool_constant{} && t.value() == u.value(); +} + +// Not equal to anything else +template +CUTE_HOST_DEVICE constexpr +false_type +operator==(ScaledBasis const&, U const&) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +false_type +operator==(T const&, ScaledBasis const&) { + return {}; +} + +// Abs +template +CUTE_HOST_DEVICE constexpr +auto +abs(ScaledBasis const& e) { + return ScaledBasis{abs(e.value())}; +} + +// Multiplication +template +CUTE_HOST_DEVICE constexpr +auto +operator*(A const& a, ScaledBasis const& e) { + auto r = a * e.value(); + return ScaledBasis{r}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator*(ScaledBasis const& e, B const& b) { + auto r = e.value() * b; + return ScaledBasis{r}; +} + +// Addition +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(N+1, int(sizeof...(U))); + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, ScaledBasis const& u) { + constexpr int R = cute::max(int(sizeof...(T)), M+1); + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, tuple const& u) { + constexpr int R = cute::max(N+1, int(sizeof...(U))); + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(tuple const& t, ScaledBasis const& u) { + constexpr int R = cute::max(int(sizeof...(T)), M+1); + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, ScaledBasis const& u) { + constexpr int R = cute::max(N+1,M+1); + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(C, ScaledBasis const& u) { + static_assert(t == 0, "ScaledBasis op+ error!"); + return u; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, C) { + static_assert(u == 0, "ScaledBasis op+ error!"); + return t; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) +{ + printf("ArithTuple"); print(iter.coord_); +} + +template +CUTE_HOST_DEVICE void print(ScaledBasis const& e) +{ + print(e.value()); printf("@%d", N); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator const& iter) +{ + return os << "ArithTuple" << iter.coord_; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) +{ + return os << e.value() << "@" << N; +} +#endif + +} // end namespace cute + + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct 
tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/server/punica_kernels/include/cutlass/cute/numeric/complex.hpp b/server/punica_kernels/include/cutlass/cute/numeric/complex.hpp new file mode 100644 index 00000000..8cc36253 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/numeric/complex.hpp @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +namespace cute +{ + +using cutlass::complex; +using cutlass::is_complex; +using cutlass::RealType; +using cutlass::real; +using cutlass::imag; +using cutlass::conj; + +template +static constexpr auto is_complex_v = is_complex::value; + +/// Fused multiply-add for complex numbers +template +CUTE_HOST_DEVICE constexpr +void +fma(complex & d, + complex const& a, + complex const& b, + complex const& c) +{ + d.real(fma( a.real(), b.real(), c.real())); + d.imag(fma( a.real(), b.imag(), c.imag())); + d.real(fma(-a.imag(), b.imag(), d.real())); + d.imag(fma( a.imag(), b.real(), d.imag())); +} + +/// Fused multiply-add for triplets +template +CUTE_HOST_DEVICE constexpr +void +fma(complex const& a, + complex const& b, + complex & c) +{ + return fma(c, a, b, c); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/numeric/int.hpp b/server/punica_kernels/include/cutlass/cute/numeric/int.hpp new file mode 100644 index 00000000..169e3a0e --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/numeric/int.hpp @@ -0,0 +1,104 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include + +namespace cute +{ + +// +// Signed integers +// + +using int2_t = cutlass::int2b_t; +using int4_t = cutlass::int4b_t; +using CUTE_STL_NAMESPACE::int8_t; +using CUTE_STL_NAMESPACE::int16_t; +using CUTE_STL_NAMESPACE::int32_t; +using CUTE_STL_NAMESPACE::int64_t; + +template struct int_bit; +template <> struct int_bit< 2> { using type = cutlass::int2b_t; }; +template <> struct int_bit< 4> { using type = cutlass::int4b_t; }; +template <> struct int_bit< 8> { using type = int8_t; }; +template <> struct int_bit< 16> { using type = int16_t; }; +template <> struct int_bit< 32> { using type = int32_t; }; +template <> struct int_bit< 64> { using type = int64_t; }; + +template +using int_bit_t = typename int_bit::type; + +template +using int_byte = int_bit<8*N>; + +template +using int_byte_t = typename int_byte::type; + +// +// Unsigned integers +// + +using uint1_t = cutlass::uint1b_t; +using uint2_t = cutlass::uint2b_t; +using uint4_t = cutlass::uint4b_t; +using CUTE_STL_NAMESPACE::uint8_t; +using CUTE_STL_NAMESPACE::uint16_t; +using CUTE_STL_NAMESPACE::uint32_t; +using CUTE_STL_NAMESPACE::uint64_t; +using cutlass::uint128_t; + +template struct uint_bit; +template <> struct uint_bit< 1> { using type = cutlass::uint1b_t; }; +template <> struct uint_bit< 2> { using type = cutlass::uint2b_t; }; +template <> struct uint_bit< 4> { using type = cutlass::uint4b_t; }; +template <> struct uint_bit< 8> { using type = uint8_t; }; +template <> struct uint_bit< 16> { using type = uint16_t; }; +template <> struct uint_bit< 32> { using type = uint32_t; }; +template <> struct uint_bit< 64> { using type = uint64_t; }; +template <> struct uint_bit<128> { using type = cutlass::uint128_t; }; + +template +using uint_bit_t = typename uint_bit::type; + +template +using uint_byte = uint_bit<8*N>; + +template +using uint_byte_t = typename uint_byte::type; + +} // namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/numeric/integer_sequence.hpp b/server/punica_kernels/include/cutlass/cute/numeric/integer_sequence.hpp new file mode 100644 index 00000000..60801795 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/numeric/integer_sequence.hpp @@ -0,0 +1,151 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +namespace cute +{ + +using CUTE_STL_NAMESPACE::integer_sequence; +using CUTE_STL_NAMESPACE::make_integer_sequence; + +namespace detail { + +template +struct range_impl; + +template +struct range_impl, Begin> { + using type = integer_sequence; +}; + +template +struct reverse_impl; + +template +struct reverse_impl> { + using type = integer_sequence; +}; + +} // end namespace detail + +template +using make_integer_range = typename detail::range_impl< + T, + make_integer_sequence 0) ? (End-Begin) : 0>, + Begin>::type; + +template +using make_integer_sequence_reverse = typename detail::reverse_impl< + make_integer_sequence>::type; + +// +// Common aliases +// + +// int_sequence + +template +using int_sequence = integer_sequence; + +template +using make_int_sequence = make_integer_sequence; + +template +using make_int_rsequence = make_integer_sequence_reverse; + +template +using make_int_range = make_integer_range; + +// index_sequence + +template +using index_sequence = integer_sequence; + +template +using make_index_sequence = make_integer_sequence; + +template +using make_index_rsequence = make_integer_sequence_reverse; + +template +using make_index_range = make_integer_range; + +// +// Shortcuts +// + +template +using seq = int_sequence; + +template +using make_seq = make_int_sequence; + +template +using make_rseq = make_int_rsequence; + +template +using make_range = make_int_range; + +template +using tuple_seq = make_seq>::value>; + +template +using tuple_rseq = make_rseq>::value>; + +// +// Specialize cute::tuple-traits for std::integer_sequence +// + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> +{ + constexpr static T idx[sizeof...(Is)] = {Is...}; + using type = cute::integral_constant; +}; + +template +CUTE_HOST_DEVICE constexpr +tuple_element_t> +get(integer_sequence) { + static_assert(I < sizeof...(Ints), "Index out of range"); + return {}; +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/numeric/integral_constant.hpp b/server/punica_kernels/include/cutlass/cute/numeric/integral_constant.hpp new file mode 100644 index 00000000..8b74aae4 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/numeric/integral_constant.hpp @@ -0,0 +1,477 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/util/print.hpp" +#include "cute/util/type_traits.hpp" +#include "cute/numeric/math.hpp" + +namespace cute +{ + +// A constant value: short name and type-deduction for fast compilation +template +struct C { + using type = C; + static constexpr auto value = v; + using value_type = decltype(v); + CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } + CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } +}; + +// Deprecate +template +using constant = C; + +template +using bool_constant = C; + +using true_type = bool_constant; +using false_type = bool_constant; + +// A more std:: conforming integral_constant that enforces type but interops with C +template +struct integral_constant : C { + using type = integral_constant; + static constexpr T value = v; + using value_type = T; + // Disambiguate C::operator value_type() + //CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } + CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } +}; + +// +// Traits +// + +// Use cute::is_std_integral to match built-in integral types (int, int64_t, unsigned, etc) +// Use cute::is_integral to match both built-in integral types AND static integral types. 
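[Editorial note] To make the distinction drawn in the comment above concrete, here is a minimal host-side sketch; it is not part of the patch, it assumes the vendored headers are reachable on the include path as `<cute/numeric/integral_constant.hpp>`, and the variable names are purely illustrative.

```cpp
// Sketch only: cute::C / cute::Int and the is_std_integral vs is_integral distinction.
#include <cute/numeric/integral_constant.hpp>   // assumed include path

int main() {
  using namespace cute;

  auto four = Int<4>{};   // static integer: a C<4>, the value lives in the type
  int  n    = 4;          // dynamic integer: the value is only known at run time

  // Arithmetic between static integers stays static: the result is again a C<...>.
  auto eight = four + four;
  static_assert(decltype(eight)::value == 8, "computed at compile time");

  // is_std_integral matches only built-in integral types.
  static_assert( is_std_integral<decltype(n)>::value,    "int is std-integral");
  static_assert(!is_std_integral<decltype(four)>::value, "C<4> is not");

  // is_integral matches built-in AND static integral types.
  static_assert(is_integral<decltype(n)>::value,    "int is integral");
  static_assert(is_integral<decltype(four)>::value, "C<4> is integral too");

  return n + four;        // a C<v> converts to its value_type where a value is needed
}
```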
+ +template +struct is_integral : bool_constant::value> {}; +template +struct is_integral > : true_type {}; +template +struct is_integral> : true_type {}; + +// is_static detects if an (abstract) value is defined completely by it's type (no members) + +template +struct is_static : bool_constant>::value> {}; + +template +constexpr bool is_static_v = is_static::value; + +// is_constant detects if a type is a static integral type and if v is equal to a value + +template +struct is_constant : false_type {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant > : bool_constant {}; +template +struct is_constant> : bool_constant {}; + +// +// Specializations +// + +template +using Int = C; + +using _m32 = Int<-32>; +using _m24 = Int<-24>; +using _m16 = Int<-16>; +using _m12 = Int<-12>; +using _m10 = Int<-10>; +using _m9 = Int<-9>; +using _m8 = Int<-8>; +using _m7 = Int<-7>; +using _m6 = Int<-6>; +using _m5 = Int<-5>; +using _m4 = Int<-4>; +using _m3 = Int<-3>; +using _m2 = Int<-2>; +using _m1 = Int<-1>; +using _0 = Int<0>; +using _1 = Int<1>; +using _2 = Int<2>; +using _3 = Int<3>; +using _4 = Int<4>; +using _5 = Int<5>; +using _6 = Int<6>; +using _7 = Int<7>; +using _8 = Int<8>; +using _9 = Int<9>; +using _10 = Int<10>; +using _12 = Int<12>; +using _16 = Int<16>; +using _24 = Int<24>; +using _32 = Int<32>; +using _64 = Int<64>; +using _96 = Int<96>; +using _128 = Int<128>; +using _192 = Int<192>; +using _256 = Int<256>; +using _384 = Int<384>; +using _512 = Int<512>; +using _768 = Int<768>; +using _1024 = Int<1024>; +using _2048 = Int<2048>; +using _4096 = Int<4096>; +using _8192 = Int<8192>; +using _16384 = Int<16384>; +using _32768 = Int<32768>; +using _65536 = Int<65536>; +using _131072 = Int<131072>; +using _262144 = Int<262144>; +using _524288 = Int<524288>; + +/***************/ +/** Operators **/ +/***************/ + +#define CUTE_LEFT_UNARY_OP(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + C<(OP t)> operator OP (C) { \ + return {}; \ + } +#define CUTE_RIGHT_UNARY_OP(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + C<(t OP)> operator OP (C) { \ + return {}; \ + } +#define CUTE_BINARY_OP(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + C<(t OP u)> operator OP (C, C) { \ + return {}; \ + } + +CUTE_LEFT_UNARY_OP(+); +CUTE_LEFT_UNARY_OP(-); +CUTE_LEFT_UNARY_OP(~); +CUTE_LEFT_UNARY_OP(!); +CUTE_LEFT_UNARY_OP(*); + +CUTE_BINARY_OP( +); +CUTE_BINARY_OP( -); +CUTE_BINARY_OP( *); +CUTE_BINARY_OP( /); +CUTE_BINARY_OP( %); +CUTE_BINARY_OP( &); +CUTE_BINARY_OP( |); +CUTE_BINARY_OP( ^); +CUTE_BINARY_OP(<<); +CUTE_BINARY_OP(>>); + +CUTE_BINARY_OP(&&); +CUTE_BINARY_OP(||); + +CUTE_BINARY_OP(==); +CUTE_BINARY_OP(!=); +CUTE_BINARY_OP( >); +CUTE_BINARY_OP( <); +CUTE_BINARY_OP(>=); +CUTE_BINARY_OP(<=); + +#undef CUTE_BINARY_OP +#undef CUTE_LEFT_UNARY_OP +#undef CUTE_RIGHT_UNARY_OP + +// +// Mixed static-dynamic special cases +// + +template ::value && t == 0)> +CUTE_HOST_DEVICE constexpr +C<0> +operator*(C, U) { + return {}; +} + +template ::value && t == 0)> +CUTE_HOST_DEVICE constexpr +C<0> +operator*(U, C) { + return {}; +} + +template ::value && t == 0)> +CUTE_HOST_DEVICE constexpr +C<0> +operator/(C, U) { + return {}; +} + +template ::value && (t == 1 || t == -1))> +CUTE_HOST_DEVICE constexpr +C<0> +operator%(U, C) { + return {}; +} + +template ::value && t == 0)> +CUTE_HOST_DEVICE constexpr +C<0> +operator%(C, U) { + return 
{}; +} + +template ::value && t == 0)> +CUTE_HOST_DEVICE constexpr +C<0> +operator&(C, U) { + return {}; +} + +template ::value && t == 0)> +CUTE_HOST_DEVICE constexpr +C<0> +operator&(U, C) { + return {}; +} + +template ::value && !bool(t))> +CUTE_HOST_DEVICE constexpr +C +operator&&(C, U) { + return {}; +} + +template ::value && !bool(t))> +CUTE_HOST_DEVICE constexpr +C +operator&&(U, C) { + return {}; +} + +template ::value && bool(t))> +CUTE_HOST_DEVICE constexpr +C +operator||(C, U) { + return {}; +} + +template ::value && bool(t))> +CUTE_HOST_DEVICE constexpr +C +operator||(U, C) { + return {}; +} + +// +// Named functions from math.hpp +// + +#define CUTE_NAMED_UNARY_FN(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + C OP (C) { \ + return {}; \ + } +#define CUTE_NAMED_BINARY_FN(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + C OP (C, C) { \ + return {}; \ + } \ + template ::value)> \ + CUTE_HOST_DEVICE constexpr \ + auto OP (C, U u) { \ + return OP(t,u); \ + } \ + template ::value)> \ + CUTE_HOST_DEVICE constexpr \ + auto OP (T t, C) { \ + return OP(t,u); \ + } + +CUTE_NAMED_UNARY_FN(abs); +CUTE_NAMED_UNARY_FN(signum); +CUTE_NAMED_UNARY_FN(has_single_bit); + +CUTE_NAMED_BINARY_FN(max); +CUTE_NAMED_BINARY_FN(min); +CUTE_NAMED_BINARY_FN(shiftl); +CUTE_NAMED_BINARY_FN(shiftr); +CUTE_NAMED_BINARY_FN(gcd); +CUTE_NAMED_BINARY_FN(lcm); + +#undef CUTE_NAMED_UNARY_FN +#undef CUTE_NAMED_BINARY_FN + +// +// Other functions +// + +template +CUTE_HOST_DEVICE constexpr +C +safe_div(C, C) { + static_assert(t % u == 0, "Static safe_div requires t % u == 0"); + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +safe_div(C, U u) { + return t / u; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +safe_div(T t, C) { + return t / u; +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +conditional_return(true_type, TrueType&& t, FalseType&&) { + return static_cast(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +conditional_return(false_type, TrueType&&, FalseType&& f) { + return static_cast(f); +} + +// TrueType and FalseType must have a common type +template +CUTE_HOST_DEVICE constexpr +auto +conditional_return(bool b, TrueType const& t, FalseType const& f) { + return b ? t : f; +} + +// TrueType and FalseType don't require a common type +template +CUTE_HOST_DEVICE constexpr +auto +conditional_return(TrueType const& t, FalseType const& f) { + if constexpr (b) { + return t; + } else { + return f; + } +} + +template +CUTE_HOST_DEVICE constexpr +auto +static_value() +{ + if constexpr (is_std_integral::value) { + return Int{}; + } else { + return Trait::value; + } + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(C) { + printf("_"); + ::cute::print(Value); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, C const&) { + return os << "_" << t; +} +#endif + + +namespace detail { + +// parse_int_digits takes a variadic number of digits and converts them into an int +template +constexpr uint64_t parse_int_digits(uint64_t result, int digit, Ts... digits) +{ + if constexpr (sizeof...(Ts) == 0) { + return 10 * result + digit; + } else { + return parse_int_digits(10 * result + digit, digits...); + } +} + +} // end namespace detail + + +// This user-defined literal operator allows cute::constant written as literals. For example, +// +// auto var = 32_c; +// +// var has type cute::constant. 
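[Editorial note] The zero- and one-absorbing overloads above, together with the `_c` literal defined next, are what let statically known operands fold away before any runtime arithmetic is emitted. A short sketch follows; it is not part of the patch, uses the same assumed include path, and the function name is made up.

```cpp
// Sketch only: mixed static/dynamic arithmetic with cute's compile-time constants.
#include <cute/numeric/integral_constant.hpp>   // assumed include path

int scale(int runtime_value) {
  using namespace cute;                 // brings operator""_c into scope

  auto zero = 0_c;                      // a C<0>
  auto two  = 2_c;                      // a C<2>

  // C<0> times a dynamic integer collapses to C<0>: no multiply is emitted.
  auto a = zero * runtime_value;
  static_assert(is_constant<0, decltype(a)>::value, "statically zero");

  // Static times static stays static ...
  auto b = two * 3_c;                   // a C<6>
  static_assert(decltype(b)::value == 6, "still a compile-time constant");

  // ... while static times dynamic falls back to ordinary integer arithmetic
  // through C's conversion operator.
  int c = two * runtime_value;

  return c + b + a;
}
```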
+// +template +constexpr cute::constant operator "" _c() +{ + static_assert((('0' <= digits && digits <= '9') && ...), + "Expected 0 <= digit <= 9 for each digit of the integer."); + return {}; +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/numeric/integral_ratio.hpp b/server/punica_kernels/include/cutlass/cute/numeric/integral_ratio.hpp new file mode 100644 index 00000000..943b0049 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/numeric/integral_ratio.hpp @@ -0,0 +1,265 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +namespace cute +{ + +/** Compile-time rational arithmetic type. + * Like cute::C for std::integral_constant, cute::R for std::ratio has a short name + * for error messages and compile times. + * The static data members @a num and @a den represent the reduced numerator and denominator + * of the rational value. Thus, two cute::R types with different @a n or @a d are distinct types + * even if they represent the same rational value. + * A cute::R exposes the reduced canonical type via its ::type member. + * That is, cute::R<3,6>::type is cute::R<1,2> and cute::R<6,3>::type is cute::C<2>. + * A cute::R::value can be used much like any other trait::value. It can be involved in + * arithmetic expressions (according to the operator-overloads for cute::C and cute::R, + * though these may be incomplete) but with a potential rational value rather than an integral value. 
+ */ +template +class R { + static_assert(d != 0); + static constexpr auto an = abs(n); + static constexpr auto ad = abs(d); + static constexpr auto g = gcd(an, ad); + + public: + static constexpr auto num = signum(n) * signum(d) * an / g; + static constexpr auto den = ad / g; + // RI: den >= 1 && gcd(abs(num),den) == 1 + using type = typename conditional, R>::type; +}; + +template +struct is_ratio : false_type {}; +template +struct is_ratio> : true_type {}; + +template +CUTE_HOST_DEVICE constexpr +typename R::type +ratio(C, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +ratio(C, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +ratio(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +ratio(R, R) { + return {}; +} + +// +// Non-reduced ratio implementations +// + +template +CUTE_HOST_DEVICE constexpr +R +nratio(C, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +R +nratio(C, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +R +nratio(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +R +nratio(R, R) { + return {}; +} + +// +// Operators +// + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator*(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator*(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator*(C, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator/(C, R) { + return {}; +} + +// Product with dynamic type needs to produce an integer... +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +operator*(C const& c, R) { + return c * R::num / R::den; +} + +// Product with dynamic type needs to produce an integer... 
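[Editorial note] A brief sketch of the reduction behaviour documented in the comment above; it is not part of the patch, the include paths are assumptions, and the helper name `half_of` is made up.

```cpp
// Sketch only: cute::R reduces to canonical form, and decays to C<v> when the
// reduced denominator is 1.
#include <cute/numeric/integral_constant.hpp>   // assumed include paths
#include <cute/numeric/integral_ratio.hpp>
#include <type_traits>

using namespace cute;

// 3/6 and 1/2 reduce to the same type, R<1,2>.
static_assert(std::is_same_v<decltype(ratio(Int<3>{}, Int<6>{})),
                             decltype(ratio(Int<1>{}, Int<2>{}))>,
              "ratios are reduced");

// 6/3 reduces to a whole number, so it decays to the static integer C<2>.
static_assert(std::is_same_v<decltype(ratio(Int<6>{}, Int<3>{})), C<2>>,
              "whole ratios decay to C<v>");

// A product with a dynamic integer produces an integer result.
inline int half_of(int n) { return n * ratio(Int<1>{}, Int<2>{}); }
```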
+template ::value)> +CUTE_HOST_DEVICE constexpr +auto +operator*(R, C const& c) { + return c * R::num / R::den; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator+(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator+(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator+(C, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +bool_constant::num == R::num && R::den == R::den> +operator==(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +bool_constant::num == c && R::den == 1> +operator==(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +bool_constant::num == c && R::den == 1> +operator==(C, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +abs(R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +int32_t +log_2(R) { + static_assert(R::num > 0); + static_assert(R::den > 0); + return log_2(static_cast(R::num)) - log_2(static_cast(R::den)); +} + +// @return A non-reduced ratio cute::R of the Trait0::value / Trait1::value +template +CUTE_HOST_DEVICE constexpr +auto +trait_ratio(Trait0, Trait1) { + return nratio(static_value(), static_value()); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(R) { + print(C{}); print("/"); print(C{}); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, R) { + return os << "_" << C{} << "/" << C{}; +} +#endif + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/numeric/math.hpp b/server/punica_kernels/include/cutlass/cute/numeric/math.hpp new file mode 100644 index 00000000..5be50339 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/numeric/math.hpp @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +// +// Common Operations +// + +template ::value && + is_arithmetic::value)> +CUTE_HOST_DEVICE constexpr +auto +max(T const& t, U const& u) { + return t < u ? u : t; +} + +template ::value && + is_arithmetic::value)> +CUTE_HOST_DEVICE constexpr +auto +min(T const& t, U const& u) { + return t < u ? t : u; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +abs(T const& t) { + if constexpr (is_signed::value) { + return t < T(0) ? -t : t; + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero. +template ::value)> +CUTE_HOST_DEVICE constexpr +int +signum(T const& x) { + if constexpr (is_signed::value) { + return (T(0) < x) - (x < T(0)); + } else { + return T(0) < x; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// C++17 operations +// + +// Greatest common divisor of two positive integers +template ::value && + is_std_integral::value)> +CUTE_HOST_DEVICE constexpr +cute::common_type_t +gcd(T t, U u) { + while (true) { + if (t == 0) { return u; } + u %= t; + if (u == 0) { return t; } + t %= u; + } +} + +// Least common multiple of two positive integers +template ::value && + is_std_integral::value)> +CUTE_HOST_DEVICE constexpr +cute::common_type_t +lcm(T const& t, U const& u) { + return (t / gcd(t,u)) * u; +} + +// +// C++20 operations +// + +// Checks if a number is an integral power of two +template +CUTE_HOST_DEVICE constexpr +bool +has_single_bit(T x) { + return x != 0 && (x & (x - 1)) == 0; +} + +// Smallest number of bits needed to represent the given value +// For x == 0, this is 0 +// For x != 0, this is 1 + floor(log2(x)) +// bit_width( 0b0000 ) = 0 +// bit_width( 0b0001 ) = 1 +// bit_width( 0b0010 ) = 2 +// bit_width( 0b0011 ) = 2 +// bit_width( 0b0100 ) = 3 +// bit_width( 0b0101 ) = 3 +// bit_width( 0b0110 ) = 3 +// bit_width( 0b0111 ) = 3 +template +CUTE_HOST_DEVICE constexpr +T +bit_width(T x) { + static_assert(is_unsigned::value, "Only to be used for unsigned types."); + constexpr int N = (numeric_limits::digits == 64 ? 6 : + (numeric_limits::digits == 32 ? 5 : + (numeric_limits::digits == 16 ? 4 : + (numeric_limits::digits == 8 ? 
3 : (assert(false),0))))); + T r = 0; + for (int i = N - 1; i >= 0; --i) { + T shift = (x > ((T(1) << (T(1) << i))-1)) << i; + x >>= shift; + r |= shift; + } + return r + (x != 0); +} + +// Smallest integral power of two not less than the given value +// bit_ceil( 0b00000000 ) = 0b00000001 +// bit_ceil( 0b00000001 ) = 0b00000001 +// bit_ceil( 0b00000010 ) = 0b00000010 +// bit_ceil( 0b00000011 ) = 0b00000100 +// bit_ceil( 0b00000100 ) = 0b00000100 +// bit_ceil( 0b00000101 ) = 0b00001000 +// bit_ceil( 0b00000110 ) = 0b00001000 +// bit_ceil( 0b00000111 ) = 0b00001000 +// bit_ceil( 0b00001000 ) = 0b00001000 +// bit_ceil( 0b00001001 ) = 0b00010000 +template +CUTE_HOST_DEVICE constexpr +T +bit_ceil(T x) { + return x == 0 ? T(1) : (T(1) << bit_width(x - 1)); +} + +// Largest integral power of two not greater than the given value +// bit_floor( 0b00000000 ) = 0b00000000 +// bit_floor( 0b00000001 ) = 0b00000001 +// bit_floor( 0b00000010 ) = 0b00000010 +// bit_floor( 0b00000011 ) = 0b00000010 +// bit_floor( 0b00000100 ) = 0b00000100 +// bit_floor( 0b00000101 ) = 0b00000100 +// bit_floor( 0b00000110 ) = 0b00000100 +// bit_floor( 0b00000111 ) = 0b00000100 +// bit_floor( 0b00001000 ) = 0b00001000 +// bit_floor( 0b00001001 ) = 0b00001000 +template +CUTE_HOST_DEVICE constexpr +T +bit_floor(T x) { + return x == 0 ? 0 : (T(1) << (bit_width(x) - 1)); +} + +template +CUTE_HOST_DEVICE constexpr T rotl(T x, int s); +template +CUTE_HOST_DEVICE constexpr T rotr(T x, int s); + +// Computes the result of circular bitwise left-rotation +template +CUTE_HOST_DEVICE constexpr +T +rotl(T x, int s) { + constexpr int N = numeric_limits::digits; + return static_cast(s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s)); +} + +// Computes the result of circular bitwise right-rotation +template +CUTE_HOST_DEVICE constexpr +T +rotr(T x, int s) { + constexpr int N = numeric_limits::digits; + return static_cast(s == 0 ? x : s > 0 ? (x >> s) | (x << (N - s)) : rotl(x, -s)); +} + +// Counts the number of consecutive 0 bits, starting from the most significant bit +// countl_zero( 0b00000000 ) = 8 +// countl_zero( 0b11111111 ) = 0 +// countl_zero( 0b00011100 ) = 3 +template +CUTE_HOST_DEVICE constexpr +T +countl_zero(T x) { + return numeric_limits::digits - bit_width(x); +} + +// Counts the number of consecutive 1 bits, starting from the most significant bit +// countl_one( 0b00000000 ) = 0 +// countl_one( 0b11111111 ) = 8 +// countl_one( 0b11100011 ) = 3 +template +CUTE_HOST_DEVICE constexpr +T +countl_one(T x) { + return countl_zero(~x); +} + +// Counts the number of consecutive 0 bits, starting from the least significant bit +// countr_zero( 0b00000000 ) = 8 +// countr_zero( 0b11111111 ) = 0 +// countr_zero( 0b00011100 ) = 2 +template +CUTE_HOST_DEVICE constexpr +T +countr_zero(T x) { + return x == 0 ? 
numeric_limits::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB +} + +// Counts the number of consecutive 1 bits, starting from the least significant bit +// countr_one( 0b00000000 ) = 0 +// countr_one( 0b11111111 ) = 8 +// countr_one( 0b11100011 ) = 2 +template +CUTE_HOST_DEVICE constexpr +T +countr_one(T x) { + return countr_zero(~x); +} + +// Counts the number of 1 bits in an unsigned integer +// popcount( 0b00000000 ) = 0 +// popcount( 0b11111111 ) = 8 +// popcount( 0b00011101 ) = 4 +template +CUTE_HOST_DEVICE constexpr +int +popcount(T x) { + int c = 0; + while (x) { + ++c; + x &= x - 1; // clear the least significant bit set + } + return c; +} + +// +// Custom operations +// + +// Computes the result of bitwise left-shift +template +CUTE_HOST_DEVICE constexpr +T +shiftl(T x, int s) { + return s >= 0 ? (x << s) : (x >> -s); +} + +// Computes the result of bitwise right-shift +template +CUTE_HOST_DEVICE constexpr +T +shiftr(T x, int s) { + return s >= 0 ? (x >> s) : (x << -s); +} + +// Safe divide +// @pre t % u == 0 +// @result t / u +template ::value && + is_std_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +safe_div(T const& t, U const& u) { + //assert(t % u == 0); + return t / u; +} + +/** + * log2 computation + */ + +template +CUTE_HOST_DEVICE constexpr +int32_t +log_2(T x) { + assert(x > 0); + static_assert(is_unsigned::value, "Only to be used for unsigned integral types."); + return static_cast(bit_width(x)) - 1; +} + +} // namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/numeric/numeric_types.hpp b/server/punica_kernels/include/cutlass/cute/numeric/numeric_types.hpp new file mode 100644 index 00000000..6ee1e1e3 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/numeric/numeric_types.hpp @@ -0,0 +1,75 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include +#include + +namespace cute { + +template +struct sizeof_bits : public cutlass::sizeof_bits {}; + +// DO NOT change auto to int, sizeof_bits use integral_ratio instead of int +template +static constexpr auto sizeof_bits_v = sizeof_bits::value; + +using cutlass::bits_to_bytes; + +using cutlass::is_subbyte; + +template +static constexpr auto is_subbyte_v = is_subbyte::value; + +using cutlass::half_t; +using cutlass::bfloat16_t; + +using cutlass::tfloat32_t; + +// Umbrella floating-point 8-bit data type : type_erased_dynamic_float8_t +// This umbrella datatype can be enabled when a user provides a specific +// datatype in runtime argument list. +using cutlass::type_erased_dynamic_float8_t; +using cutlass::float_e4m3_t; +using cutlass::float_e5m2_t; + +using cutlass::uint1b_t; +using cutlass::int2b_t; +using cutlass::uint2b_t; +using cutlass::int4b_t; +using cutlass::uint4b_t; +using cutlass::bin1_t; + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/numeric/real.hpp b/server/punica_kernels/include/cutlass/cute/numeric/real.hpp new file mode 100644 index 00000000..f797bc13 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/numeric/real.hpp @@ -0,0 +1,56 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +namespace cute +{ + +/// Generic fused multiply-add +template +CUTE_HOST_DEVICE constexpr +void +fma(D& d, A const& a, B const& b, C const& c) +{ + d = a * b + c; +} + +/// Fused multiply-add for triplets +template +CUTE_HOST_DEVICE constexpr +void +fma(A const& a, B const& b, C& c) +{ + return fma(c, a, b, c); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/pointer.hpp b/server/punica_kernels/include/cutlass/cute/pointer.hpp new file mode 100644 index 00000000..5647f97c --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/pointer.hpp @@ -0,0 +1,308 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include // sizeof_bits +#include +#include + +#include + +#include +#include +#include +namespace cute +{ + +// +// recast_ptr -- Create an iterator over values of type T. +// For most types this will simply be T*, but certain types require more care. +// Subbyte Types: uint2_t, uint4_t, etc +// Requires construction of a subbyte_iterator in order to properly +// resolve each element in byte-addressed memory. 
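[Editorial note] As a sketch of the behaviour this comment describes (not part of the patch; the include paths are assumptions and `demo` is an illustrative name):

```cpp
// Sketch only: recast_ptr yields a plain pointer for byte-sized types and a
// subbyte iterator for sub-byte types such as uint4_t.
#include <cute/pointer.hpp>        // assumed include paths
#include <cute/numeric/int.hpp>
#include <cstdint>
#include <type_traits>

void demo(unsigned char* raw) {
  using namespace cute;

  // For an ordinary 8-bit type this is effectively a reinterpret_cast.
  auto p8 = recast_ptr<std::uint8_t>(raw);
  static_assert(std::is_same_v<decltype(p8), std::uint8_t*>, "plain pointer");

  // For a 4-bit type the result is a subbyte iterator, so each element can be
  // resolved individually inside byte-addressed memory.
  auto p4 = recast_ptr<uint4_t>(raw);
  static_assert(!std::is_pointer_v<decltype(p4)>, "wrapped, not a raw pointer");
  p4[0] = uint4_t(3);   // writes one 4-bit element packed into raw[0]
  p4[1] = uint4_t(7);   // the neighbouring 4-bit element in the same byte
}
```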
+// + +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(void* ptr) +{ + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } else { + return reinterpret_cast(ptr); + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(void const* ptr) +{ + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } else { + return reinterpret_cast(ptr); + } + CUTE_GCC_UNREACHABLE; +} + +// Disambiguate nullptr +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(decltype(nullptr)) { // nullptr_t + return recast_ptr(static_cast(nullptr)); +} + +// +// gmem_ptr +// + +template +struct gmem_ptr : iter_adaptor> { + using iter_adaptor>::iter_adaptor; +}; + +template +struct is_gmem : false_type {}; +template // Found the gmem +struct is_gmem> : true_type {}; +template // Recurse on ::iterator, if possible +struct is_gmem> : is_gmem {}; + +// Idempotent gmem tag on an iterator +template +CUTE_HOST_DEVICE constexpr +auto +make_gmem_ptr(Iterator iter) { + if constexpr (is_gmem::value) { + return iter; + } else { + return gmem_ptr{iter}; + } + CUTE_GCC_UNREACHABLE; +} + +// Explicitly typed construction from a raw pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_gmem_ptr(void* ptr) { + return make_gmem_ptr(recast_ptr(ptr)); +} + +// Explicitly typed construction from a raw pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_gmem_ptr(void const* ptr) { + return make_gmem_ptr(recast_ptr(ptr)); +} + +// nullptr_t overload for make_gmem_ptr(nullptr) disambiguation +template +CUTE_HOST_DEVICE constexpr +auto +make_gmem_ptr(decltype(nullptr)) { // nullptr_t + return make_gmem_ptr(recast_ptr(nullptr)); +} + +// The gmem tag is invariant over type-recast +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(gmem_ptr

const& ptr) { + return make_gmem_ptr(recast_ptr(ptr.get())); +} + +// +// smem_ptr +// + +template +struct smem_ptr : iter_adaptor> { + using iter_adaptor>::iter_adaptor; +}; + +template +struct is_smem : false_type {}; +template // Found the smem +struct is_smem> : true_type {}; +template // Recurse on ::iterator, if possible +struct is_smem> : is_smem {}; + +// Idempotent smem tag on an iterator +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(Iterator iter) { + if constexpr (is_smem::value) { + return iter; + } else { + return smem_ptr{iter}; + } + CUTE_GCC_UNREACHABLE; +} + +// Make a smem swizzle pointer, common operation +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(Iterator ptr, Swizzle sw) +{ + return make_swizzle_ptr(make_smem_ptr(ptr), sw); +} + +// Explicitly typed construction from a raw pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(void* ptr) { + return make_smem_ptr(recast_ptr(ptr)); +} + +// Explicitly typed construction from a raw pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(void const* ptr) { + return make_smem_ptr(recast_ptr(ptr)); +} + +// The smem tag is invariant over type-recast +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(smem_ptr

const& ptr) { + return make_smem_ptr(recast_ptr(ptr.get())); +} + +// +// rmem_ptr +// + +template +struct rmem_ptr : iter_adaptor> { + using iter_adaptor>::iter_adaptor; +}; + +// Anything that is not gmem or smem is rmem +template +struct is_rmem : bool_constant::value || is_smem::value)> {}; +template +struct is_rmem> : true_type {}; + +// Idempotent rmem tag on an iterator +template +CUTE_HOST_DEVICE constexpr +auto +make_rmem_ptr(Iterator iter) { + if constexpr (is_rmem::value) { + return iter; + } else { + return rmem_ptr{iter}; + } + CUTE_GCC_UNREACHABLE; +} + +// Explicitly typed construction from a raw pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_rmem_ptr(void* ptr) { + return make_rmem_ptr(recast_ptr(ptr)); +} + +// Explicitly typed construction from a raw pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_rmem_ptr(void const* ptr) { + return make_rmem_ptr(recast_ptr(ptr)); +} + +// The rmem tag is invariant over type-recast +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(rmem_ptr

const& ptr) { + return make_rmem_ptr(recast_ptr(ptr.get())); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(gmem_ptr ptr) +{ + printf("gmem_"); print(ptr.get()); +} + +template +CUTE_HOST_DEVICE void print(smem_ptr ptr) +{ + printf("smem_"); print(ptr.get()); +} + +template +CUTE_HOST_DEVICE void print(rmem_ptr ptr) +{ + printf("rmem_"); print(ptr.get()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr ptr) +{ + return os << "gmem_[" << int(sizeof_bits>::value) << "b]"; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr ptr) +{ + return os << "smem_[" << int(sizeof_bits>::value) << "b]"; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr ptr) +{ + return os << "rmem_[" << int(sizeof_bits>::value) << "b]"; +} + +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/pointer_base.hpp b/server/punica_kernels/include/cutlass/cute/pointer_base.hpp new file mode 100644 index 00000000..db5d3dcf --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/pointer_base.hpp @@ -0,0 +1,247 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include // sizeof_bits + +namespace cute +{ + +// +// C++20 iterator_traits +// + +namespace detail { +// Default reference type of an iterator +template +struct iter_ref { using type = decltype(*declval()); }; +// Prefer to propagate ::reference +template +struct iter_ref> { using type = typename T::reference; }; +} // end namespace detail + +template +using iter_reference = detail::iter_ref; +template +using iter_reference_t = typename iter_reference::type; + +namespace detail { +// Default element_type of an iterator +template +struct iter_e { using type = remove_reference_t::type>; }; +// Prefer to propagate ::element_type +template +struct iter_e> { using type = typename T::element_type; }; +} // end namespace detail + +template +using iter_element = detail::iter_e; +template +using iter_element_t = typename iter_element::type; + +namespace detail { +// Default value_type of an iterator +template +struct iter_v { using type = remove_cv_t::type>; }; +// Prefer to propagate ::value_type +template +struct iter_v> { using type = typename T::value_type; }; +} // end namespace detail + +template +using iter_value = detail::iter_v; +template +using iter_value_t = typename iter_value::type; + +template +struct iterator_traits { + using reference = iter_reference_t; + using element_type = iter_element_t; + using value_type = iter_value_t; +}; + +// +// has_dereference to determine if a type is an iterator concept +// + +namespace detail { +template +struct has_dereference : CUTE_STL_NAMESPACE::false_type {}; +template +struct has_dereference())>> : CUTE_STL_NAMESPACE::true_type {}; +} // end namespace detail + +template +using has_dereference = detail::has_dereference; + +// +// raw_pointer_cast +// + +template +CUTE_HOST_DEVICE constexpr +T* +raw_pointer_cast(T* ptr) { + return ptr; +} + +// +// A very simplified iterator adaptor. +// Derived classed may override methods, but be careful to reproduce interfaces exactly. +// Clients should never have an instance of this class. Do not write methods that take this as a param. 
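[Editorial note] To show how this adaptor is meant to be used, only ever through a derived type, exactly as gmem_ptr, smem_ptr and rmem_ptr do above, here is a small sketch; `my_ptr`, `demo` and the include path are assumptions, not part of the patch.

```cpp
// Sketch only: a trivially tagged pointer built on iter_adaptor, in the same
// style as the gmem/smem/rmem pointer wrappers.
#include <cute/pointer_base.hpp>   // assumed include path

namespace demo {

template <class Iterator>
struct my_ptr : cute::iter_adaptor<Iterator, my_ptr<Iterator>> {
  using cute::iter_adaptor<Iterator, my_ptr<Iterator>>::iter_adaptor;
};

} // namespace demo

int test() {
  int data[4] = {10, 20, 30, 40};
  demo::my_ptr<int*> p{data};

  int a = *p;          // dereference forwards to the wrapped iterator
  int b = p[2];        // indexing forwards as well
  auto q = p + 1;      // offsetting returns another my_ptr<int*>, not a raw int*
  return a + b + *q;   // 10 + 30 + 20
}
```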
+// + +template +struct iter_adaptor +{ + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + + iterator ptr_; + + CUTE_HOST_DEVICE constexpr + iter_adaptor(iterator ptr = {}) : ptr_(ptr) {} + + CUTE_HOST_DEVICE constexpr + reference operator*() const { return *ptr_; } + + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { return ptr_[i]; } + + template + CUTE_HOST_DEVICE constexpr + DerivedType operator+(Index const& i) const { return {ptr_ + i}; } + + CUTE_HOST_DEVICE constexpr + iterator get() const { return ptr_; } + + CUTE_HOST_DEVICE constexpr + friend bool operator==(DerivedType const& x, DerivedType const& y) { return x.ptr_ == y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator!=(DerivedType const& x, DerivedType const& y) { return x.ptr_ != y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator< (DerivedType const& x, DerivedType const& y) { return x.ptr_ < y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator<=(DerivedType const& x, DerivedType const& y) { return x.ptr_ <= y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator> (DerivedType const& x, DerivedType const& y) { return x.ptr_ > y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator>=(DerivedType const& x, DerivedType const& y) { return x.ptr_ >= y.ptr_; } +}; + +template +CUTE_HOST_DEVICE constexpr +auto +raw_pointer_cast(iter_adaptor const& x) { + return raw_pointer_cast(x.ptr_); +} + +// +// counting iterator -- quick and dirty +// + +template +struct counting_iterator +{ + using index_type = T; + using value_type = T; + using reference = T; + + index_type n_; + + CUTE_HOST_DEVICE constexpr + counting_iterator(index_type n = 0) : n_(n) {} + + CUTE_HOST_DEVICE constexpr + index_type operator*() const { return n_; } + + CUTE_HOST_DEVICE constexpr + index_type operator[](index_type i) const { return n_ + i; } + + CUTE_HOST_DEVICE constexpr + counting_iterator operator+(index_type i) const { return {n_ + i}; } + CUTE_HOST_DEVICE constexpr + counting_iterator& operator++() { ++n_; return *this; } + CUTE_HOST_DEVICE constexpr + counting_iterator operator++(int) { counting_iterator ret = *this; ++n_; return ret; } + + CUTE_HOST_DEVICE constexpr + friend bool operator==(counting_iterator const& x, counting_iterator const& y) { return x.n_ == y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator!=(counting_iterator const& x, counting_iterator const& y) { return x.n_ != y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator< (counting_iterator const& x, counting_iterator const& y) { return x.n_ < y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator<=(counting_iterator const& x, counting_iterator const& y) { return x.n_ <= y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator> (counting_iterator const& x, counting_iterator const& y) { return x.n_ > y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator>=(counting_iterator const& x, counting_iterator const& y) { return x.n_ >= y.n_; } +}; + +template +CUTE_HOST_DEVICE constexpr +T +raw_pointer_cast(counting_iterator const& x) { + return x.n_; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(T const* const ptr) +{ + printf("ptr["); print(sizeof_bits::value); printf("b](%p)", ptr); +} + +template +CUTE_HOST_DEVICE void print(counting_iterator ptr) +{ + printf("counting_iter("); print(ptr.n_); 
printf(")"); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator ptr) +{ + return os << "counting_iter(" << ptr.n_ << ")"; +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/pointer_flagged.hpp b/server/punica_kernels/include/cutlass/cute/pointer_flagged.hpp new file mode 100644 index 00000000..aa917d9f --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/pointer_flagged.hpp @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include // cast_smem_ptr_to_uint + +#include +#include +#include + +#include + +namespace cute +{ + +// +// Stand-in Swizzle Layout +// A model of a nullptr smem_ptr with B == sizeof_bits::value +// That represents an unset pointer. 
This is a placeholder type that is waiting for an smem_ptr +// + +template +struct smem_ptr_flag_bits : Int<0> {}; + +using smem_ptr_flag = smem_ptr_flag_bits<1>; + +// A flagged construction method to transform ComposedLayout +// Make a swizzle pointer tensor and check that the intended type size matches +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Iterator const& ptr, + ComposedLayout,Layout> const& layout) +{ + static_assert(is_smem::value, "Expected smem."); + static_assert(B == sizeof_bits>::value, "Expected a B-bit pointer type."); + return make_tensor(make_smem_ptr(ptr.get(), layout.layout_a()), + layout.layout_b()); +} + +// NOTE: To preserve smem_ptr_flag_bits under recast ops +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Layout> const& layout) +{ + return composition(layout.layout_a(), smem_ptr_flag_bits{}, upcast(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout,Layout> const& layout) +{ + return composition(layout.layout_a(), smem_ptr_flag_bits{}, downcast(layout.layout_b())); +} + +// +// Conversion with swizzle_layout +// + +template +CUTE_HOST_DEVICE +auto +as_position_independent_swizzle_layout(ComposedLayout,Layout> const& layout) +{ + return composition(recast_layout>(layout.layout_a()), Int<0>{}, layout.layout_b()); +} + +template +CUTE_HOST_DEVICE +auto +as_position_independent_swizzle_tensor(Tensor&& tensor) +{ + static_assert(is_smem>::value, "Expected smem tensor."); + using SwizzleFn = get_swizzle_t>; + if constexpr (SwizzleFn::num_bits == 0) { + return tensor; + } else { +#if !defined(NDEBUG) + { + uint32_t address = cast_smem_ptr_to_uint(raw_pointer_cast(static_cast(tensor).data())); + uint32_t mask = ((uint32_t(1) << SwizzleFn::num_base) - 1) | SwizzleFn::swizzle_code; + assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle + } +#endif + using T = typename remove_cvref_t::value_type; + // Recast swizzle from acting on byte-addressed pointers to elements of type-T + auto new_swizzle = recast_layout(SwizzleFn{}); + // Strip off everything and create a new smem_ptr for type-T + auto new_ptr = make_smem_ptr(raw_pointer_cast(static_cast(tensor).data())); + return make_tensor(new_ptr, composition(new_swizzle, Int<0>{}, tensor.layout())); + } + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +// Capture and cast smem_ptr_flag Layouts to offset-0 layouts +template +CUTE_HOST_DEVICE +void +print_latex(ComposedLayout,Layout> const& layout) +{ + print_latex(as_position_independent_swizzle_layout(layout)); +} + +template +CUTE_HOST_DEVICE void print(smem_ptr_flag_bits ptr) +{ + printf("smem_ptr[%db](unset)", B); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/pointer_swizzle.hpp b/server/punica_kernels/include/cutlass/cute/pointer_swizzle.hpp new file mode 100644 index 00000000..a83b485c --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/pointer_swizzle.hpp @@ -0,0 +1,172 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include // iterator_traits +#include + +#include +#include + +/* This implements a swizzle pointer of the form + * InvolutionFn o PtrAdd + * where the InvolutionFn need not be linear. + * + * This differs subtly from swizzle_layout because the smem pointer is used + * as the offset. That means that swizzle_layout will implement position-independent + * swizzle layouts, while swizzle_ptr implements position-dependent swizzle tensors. + * Arch chose to design hardware with position-dependent swizzles. + * + * For clarity: + * NormalLayout : DeRef <- PtrAdd <- [Layout] + * ComposedLayout: DeRef <- PtrAdd <- [Swizzle <- OffsetAdd <- Layout] + * SwizzlePtr : [DeRef <- Swizzle <- PtrAdd] <- Layout + * + * Furthermore, for known swizzles, this pointer attempts to decay itself + * to a normal-pointer with a new layout containing dynamic or static strides. + * This is possible by determining the subdomain of the InvolutionFn + * that is identity and testing if the Layout's codomain is contained + * within it. + */ + +namespace cute +{ + +// concept SwizzleFn { +// CUTE_HOST_DEVICE constexpr static uint apply(uint); +// } +// See Swizzle in swizzle.hpp for common swizzle-functions. 
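To make the position-dependent behaviour described above concrete, here is a minimal stand-alone sketch (not taken from the header itself): the swizzle is applied to the already-offset address, matching SwizzlePtr : [DeRef <- Swizzle <- PtrAdd]. The ToySwizzle and toy_swizzle_ptr names are invented for the example, and element indices stand in for byte addresses for brevity.

    #include <cstdint>

    // Illustrative stand-in for the SwizzleFn concept sketched above: XOR two
    // higher index bits into the two lowest bits. Applying it twice restores
    // the original value, i.e. it is an involution.
    struct ToySwizzle {
      static constexpr uint32_t apply(uint32_t i) { return i ^ ((i >> 4) & 0x3u); }
    };

    // Position-dependent swizzled "pointer": the swizzle acts on (base_offset + i),
    // i.e. after the pointer add, so the same logical index i can land on different
    // physical slots depending on where the pointer itself points.
    template <class T, class SwizzleFn>
    struct toy_swizzle_ptr {
      T*       buf;          // start of the (e.g. shared-memory) buffer
      uint32_t base_offset;  // element offset of this "pointer" within the buffer
      T& operator[](uint32_t i) const { return buf[SwizzleFn::apply(base_offset + i)]; }
    };

    static_assert(ToySwizzle::apply(ToySwizzle::apply(0x37u)) == 0x37u,
                  "a swizzle function must be an involution");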
+ +template +struct swizzle_ptr : iter_adaptor> +{ + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + + using iter_adaptor>::iter_adaptor; + + template + CUTE_HOST_DEVICE constexpr static + Iter apply_swizzle(Iter ptr) { + return {apply_swizzle(ptr.get())}; + } + + template + CUTE_HOST_DEVICE constexpr static + T* apply_swizzle(T* ptr) { + return reinterpret_cast(SwizzleFn::apply(reinterpret_cast(ptr))); + } + + template + CUTE_HOST_DEVICE constexpr static + subbyte_iterator apply_swizzle(subbyte_iterator ptr) { + return {apply_swizzle(ptr.ptr_), ptr.idx_}; + } + + CUTE_HOST_DEVICE constexpr + reference operator*() const { + return *apply_swizzle(this->get()); + } + + template + CUTE_HOST_DEVICE constexpr + reference operator[](Int const& i) const { + return *apply_swizzle(this->get() + i); + } +}; + +template // Default No-Swizzle +struct get_swizzle { using type = Swizzle<0,4,3>; }; +template // Found the SwizzleFn +struct get_swizzle> { using type = SwizzleFn; }; +template // Recurse into anything with a ::iterator +struct get_swizzle> : get_swizzle {}; + +template +using get_swizzle_t = typename get_swizzle::type; + +template +CUTE_HOST_DEVICE constexpr +swizzle_ptr +make_swizzle_ptr(Iterator ptr, SwizzleFn) { + return {ptr}; +} + +// Swizzle-0 specialization for immediate decay +template +CUTE_HOST_DEVICE constexpr +Iterator +make_swizzle_ptr(Iterator ptr, Swizzle<0,M,S>) { + return ptr; +} + +// +// Recast +// + +template +CUTE_HOST_DEVICE constexpr +auto +raw_pointer_cast(swizzle_ptr const& ptr) { + return raw_pointer_cast(ptr.get()); +} + +// SwizzleFn operates on the pointer address, so it doesn't care about the type +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(swizzle_ptr const& ptr) { + return make_swizzle_ptr(recast_ptr(ptr.get()), SwizzleFn{}); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(swizzle_ptr ptr) +{ + print(SwizzleFn{}); printf("_"); print(ptr.get()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, swizzle_ptr ptr) +{ + return os << SwizzleFn{} << "_" << ptr.get(); +} +#endif + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/stride.hpp b/server/punica_kernels/include/cutlass/cute/stride.hpp new file mode 100644 index 00000000..eb62e4bf --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/stride.hpp @@ -0,0 +1,471 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +/** crd2idx(c,s,d) maps a coordinate within to an index + * + * This is computed as follows: + * [coord, shape, and stride are all integers => step forward by stride] + * op(c, s, d) => c * d + * [coord is integer, shape and stride are tuple => divmod coord for each mode] + * op(c, (s,S), (d,D)) => op(c % prod(s), s, d) + op(c / prod(s), (S), (D)) + * [coord, shape, and stride are all tuples => consider each mode independently] + * op((c,C), (s,S), (d,D)) => op(c, s, d) + op((C), (S), (D)) + */ +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& coord, + Shape const& shape, + Stride const& stride); + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx_ttt(Coord const& coord, + Shape const& shape, + Stride const& stride, seq) +{ + return (... + crd2idx(get(coord), get(shape), get(stride))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx_itt(CInt const& coord, + STuple const& shape, + DTuple const& stride, seq) +{ + if constexpr (sizeof...(Is) == 0) { // Avoid recursion and mod on single/last iter + return crd2idx(coord, get(shape), get(stride)); + } else if constexpr (is_constant<0, CInt>::value) { + return crd2idx(_0{}, get(shape), get(stride)) + + (_0{} + ... 
+ crd2idx(_0{}, get(shape), get(stride))); + } else { // General case + return crd2idx(coord % product(get(shape)), get(shape), get(stride)) + + crd2idx_itt(coord / product(get(shape)), shape, stride, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& coord, + Shape const& shape, + Stride const& stride) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple tuple + static_assert(tuple_size::value == tuple_size< Shape>::value, "Mismatched Ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return detail::crd2idx_ttt(coord, shape, stride, tuple_seq{}); + } else { // tuple "int" "int" + static_assert(sizeof(Coord) == 0, "Invalid parameters"); + } + } else { + if constexpr (is_tuple::value) { // "int" tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return detail::crd2idx_itt(coord, shape, stride, tuple_seq{}); + } else { // "int" "int" "int" + return coord * stride; + } + } + + CUTE_GCC_UNREACHABLE; +} + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx_horner(CTuple const& coord, + STuple const& shape, seq) +{ + if constexpr (sizeof...(Is) == 0) { // No recursion on single/last iter + return get(coord); + } else { // General case + return get(coord) + get(shape) * crd2idx_horner(coord, shape, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +/** crd2idx(c,s) maps a coordinate within Shape to an index + * via a colexicographical enumeration of coordinates in Shape. + * i = c0 + s0 * (c1 + s1 * (c2 + s2 * ...)) + */ +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& coord, + Shape const& shape) +{ + if constexpr (is_integral::value) { // Coord is already an index + return coord; + } else if constexpr (is_integral::value) { + static_assert(dependent_false, "Invalid parameters"); + } else { // Make congruent, flatten, and apply Horner's method + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + auto flat_coord = flatten(coord); + auto flat_shape = flatten(product_like(shape, coord)); + return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +/** idx2crd(i,s,d) splits an index into a coordinate within . + * + * This is computed as follows: + * [index, shape, and stride are all integers => determine 1D coord] + * op(i, s, d) => (i / d) % s + * [index is integer, shape and stride are tuple => determine component for each mode] + * op(i, (s,S), (d,D)) => (op(i, s, d), op(i, S, D)...) + * [index, shape, and stride are all tuples => consider each mode independently] + * op((i,I), (s,S), (d,D)) => (op(i, s, d), op((I), (S), (D))) + * + * NOTE: This only works for compact shape+stride layouts. 
A more general version would + * apply to all surjective layouts + */ +template +CUTE_HOST_DEVICE constexpr +auto +idx2crd(Index const& idx, + Shape const& shape, + Stride const& stride) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple tuple + static_assert(tuple_size::value == tuple_size< Shape>::value, "Mismatched Ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(idx, shape, stride, [](auto const& i, auto const& s, auto const& d){ return idx2crd(i,s,d); }); + } else { // tuple "int" "int" + static_assert(sizeof(Index) == 0, "Invalid parameters"); + } + } else { + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // "int" tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(shape, stride, [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); }); + } else { // "int" tuple "int" + return transform(shape, compact_col_major(shape, stride), [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); }); + } + } else { // "int" "int" "int" + if constexpr (is_constant<1, Shape>::value) { + // Skip potential stride-0 division + return Int<0>{}; + } else { + return (idx / stride) % shape; + } + } + } + + CUTE_GCC_UNREACHABLE; +} + +/** idx2crd(i,s) splits an index into a coordinate within Shape + * via a colexicographical enumeration of coordinates in Shape. + * c0 = (idx / 1) % s0 + * c1 = (idx / s0) % s1 + * c2 = (idx / (s0 * s1)) % s2 + * ... + */ +template +CUTE_HOST_DEVICE constexpr +auto +idx2crd(Index const& idx, + Shape const& shape) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(idx, shape, [](auto const& i, auto const& s) { return idx2crd(i,s); }); + } else { // tuple "int" + static_assert(sizeof(Index) == 0, "Invalid parameters"); + } + } else { + if constexpr (is_tuple::value) { // "int" tuple + return idx2crd(idx, shape, compact_col_major(shape)); + } else { // "int" "int" + return idx; + } + } + + CUTE_GCC_UNREACHABLE; +} + +// +// crd2crd +// + +template +CUTE_HOST_DEVICE constexpr +auto +crd2crd(Coord const& coord, + SShape const& src_shape, + DShape const& dst_shape) +{ + if constexpr (is_tuple::value && is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(coord, src_shape, dst_shape, [](auto const& c, auto const& s, auto const& d) { return crd2crd(c,s,d); }); + } else { + // assert(size(src_shape) == size(dst_shape)) + return idx2crd(crd2idx(coord, src_shape), dst_shape); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Compact Major +// + +// Tags for common layouts and dispatching +struct LayoutLeft; // Col-major layout mapping; leftmost extent has stride 1 +using GenColMajor = LayoutLeft; // Alias + +struct LayoutRight; // Row-major layout mapping; rightmost extent has stride 1 +using GenRowMajor = LayoutRight; // Alias + +namespace detail { + +// For GCC8.5 -- Use of lambdas in unevaluated contexts. Instead use function objects. 
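As a quick numeric check of the two mappings documented above, here is a stand-alone sketch in plain C++ (deliberately not using the cute:: templates; the toy_ names and the rank-2 restriction are only for illustration): for shape (3,4) with compact column-major strides (1,3), crd2idx and idx2crd invert each other.

    #include <array>

    // Rank-2, all-integer versions of the formulas documented above.
    constexpr int toy_crd2idx(std::array<int,2> c, std::array<int,2> d) {
      return c[0] * d[0] + c[1] * d[1];                     // sum over modes of coord * stride
    }
    constexpr std::array<int,2> toy_idx2crd(int i, std::array<int,2> s, std::array<int,2> d) {
      return { (i / d[0]) % s[0], (i / d[1]) % s[1] };      // per mode: (idx / stride) % shape
    }

    constexpr std::array<int,2> shape  {3, 4};
    constexpr std::array<int,2> stride {1, 3};              // compact column-major: leftmost stride is 1
    static_assert(toy_crd2idx({2, 1}, stride) == 5, "2*1 + 1*3 = 5");
    static_assert(toy_idx2crd(5, shape, stride)[0] == 2 &&
                  toy_idx2crd(5, shape, stride)[1] == 1, "round trip recovers coordinate (2,1)");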
+template +struct CompactLambda; + +// @pre is_integral +// Return (result, current * product(shape)) to enable recurrence +template +CUTE_HOST_DEVICE constexpr +auto +compact(Shape const& shape, + Current const& current) +{ + if constexpr (is_tuple::value) { // Shape::tuple Current::int + using Lambda = CompactLambda; // Append or Prepend + using Seq = typename Lambda::template seq; // Seq or RSeq + return cute::detail::fold(shape, cute::make_tuple(cute::make_tuple(), current), Lambda{}, Seq{}); + } else { // Shape::int Current::int + if constexpr (is_constant<1, Shape>::value) { + return cute::make_tuple(Int<0>{}, current); // If current is dynamic, this could save a reg + } else { + return cute::make_tuple(current, current * shape); + } + } + + CUTE_GCC_UNREACHABLE; +} + +// For GCC8.5 -- Specialization LayoutLeft +template <> +struct CompactLambda +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Init const& init, Shape const& si) { + auto result = detail::compact(si, get<1>(init)); + return cute::make_tuple(append(get<0>(init), get<0>(result)), get<1>(result)); // Append + } + + template + using seq = tuple_seq; // Seq +}; + +// For GCC8.5 -- Specialization LayoutRight +template <> +struct CompactLambda +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Init const& init, Shape const& si) { + auto result = detail::compact(si, get<1>(init)); + return cute::make_tuple(prepend(get<0>(init), get<0>(result)), get<1>(result)); // Prepend + } + + template + using seq = tuple_rseq; // RSeq +}; + +} // end namespace detail + +template , + __CUTE_REQUIRES(is_tuple::value || is_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +compact_major(Shape const& shape, + Current const& current = {}) +{ + if constexpr (is_tuple::value) { // Shape::tuple Current::tuple + static_assert(is_tuple::value, "Invalid parameters"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + // Recurse to apply to the terminals of current + return transform(shape, current, [&](auto const& s, auto const& c){ return compact_major(s,c); }); + } else { + return get<0>(detail::compact(shape, current)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Compact Col Major +// + +struct LayoutLeft { + template + using Apply = decltype(compact_major(declval())); +}; + +template > +CUTE_HOST_DEVICE constexpr +auto +compact_col_major(Shape const& shape, + Current const& current = {}) +{ + return compact_major(shape, current); +} + +// +// Compact Row Major +// + +struct LayoutRight { + template + using Apply = decltype(compact_major(declval())); +}; + +template > +CUTE_HOST_DEVICE constexpr +auto +compact_row_major(Shape const& shape, + Current const& current = {}) +{ + return compact_major(shape, current); +} + +// +// Compact Order -- compute a compact stride based on an ordering of the modes +// + +namespace detail { + +// @pre weakly_congruent(order, shape) +// @pre is_congruent +// @pre is_static +// @pre is_static +template +CUTE_HOST_DEVICE constexpr +auto +compact_order(Shape const& shape, Order const& order, + RefShape const& ref_shape, RefOrder const& ref_order) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Need equal rank of shape and order"); + return transform(shape, order, [&](auto const& s, auto const& o) { return compact_order(s, o, ref_shape, ref_order); }); + } else { + // Compute the starting stride for this shape by accumulating all shapes corresponding to lesser orders + auto stride_start = product(transform(ref_shape, 
ref_order, + [&](auto const& s, auto const& o) { + return conditional_return(o < order, s, Int<1>{}); + })); + return compact_col_major(shape, stride_start); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +compact_order(Shape const& shape, Order const& order) +{ + auto ref_shape = flatten_to_tuple(product_like(shape, order)); + + auto flat_order = flatten_to_tuple(order); + // Find the largest static element of order + auto max_order = cute::fold(flat_order, Int<0>{}, [](auto v, auto order) { + if constexpr (is_constant::value) { + return order; + } else { + return v; + } + }); + // Replace any dynamic elements within order with large-static elements + auto max_seq = make_range{}; + auto ref_order = cute::transform(max_seq, flat_order, [](auto seq_v, auto order) { + if constexpr (is_static::value) { + return order; + } else { + return seq_v; + } + }); + + auto new_order = unflatten(ref_order, order); + + return detail::compact_order(shape, new_order, ref_shape, ref_order); +} + +template +CUTE_HOST_DEVICE constexpr +auto +compact_order(Shape const& shape, GenColMajor const& major) +{ + return compact_major(shape); +} + +template +CUTE_HOST_DEVICE constexpr +auto +compact_order(Shape const& shape, GenRowMajor const& major) +{ + return compact_major(shape); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/swizzle.hpp b/server/punica_kernels/include/cutlass/cute/swizzle.hpp new file mode 100644 index 00000000..57735ce1 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/swizzle.hpp @@ -0,0 +1,475 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace cute +{ + +// A generic Swizzle functor +/* 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx + * ^--^ MBase is the number of least-sig bits to keep constant + * ^-^ ^-^ BBits is the number of bits in the mask + * ^---------^ SShift is the distance to shift the YYY mask + * (pos shifts YYY to the right, neg shifts YYY to the left) + * + * e.g. Given + * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx + * the result is + * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY + */ +template +struct Swizzle +{ + static constexpr int num_bits = BBits; + static constexpr int num_base = MBase; + static constexpr int num_shft = SShift; + + static_assert(num_base >= 0, "MBase must be positive."); + static_assert(num_bits >= 0, "BBits must be positive."); + static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits."); + + // using 'int' type here to avoid unintentially casting to unsigned... unsure. + using bit_msk = cute::constant; + using yyy_msk = cute::constant; + using zzz_msk = cute::constant; + using msk_sft = cute::constant; + + static constexpr uint32_t swizzle_code = uint32_t(yyy_msk{} | zzz_msk{}); + + template + CUTE_HOST_DEVICE constexpr static + auto + apply(Offset const& offset) + { + return offset ^ shiftr(offset & yyy_msk{}, msk_sft{}); // ZZZ ^= YYY + } + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Offset const& offset) const + { + return apply(offset); + } + + template + CUTE_HOST_DEVICE constexpr + auto + operator==(Swizzle const&) const + { + return B == BBits && M == MBase && S == SShift; + } +}; + +// +// make_swizzle<0b1000, 0b0100>() -> Swizzle<1,2,1> +// make_swizzle<0b11000000, 0b00000110>() -> Swizzle<2,1,5> +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_swizzle() +{ + constexpr uint32_t BZ = popcount(Y); // Number of swizzle bits + constexpr uint32_t BY = popcount(Z); // Number of swizzle bits + static_assert(BZ == BY, "Number of bits in Y and Z don't match"); + constexpr uint32_t TZ_Y = countr_zero(Y); // Number of trailing zeros in Y + constexpr uint32_t TZ_Z = countr_zero(Z); // Number of trailing zeros in Z + constexpr uint32_t M = cute::min(TZ_Y, TZ_Z) % 32; + constexpr int32_t S = int32_t(TZ_Y) - int32_t(TZ_Z); // Difference in trailing zeros + static_assert((Y | Z) == Swizzle::swizzle_code, "Something went wrong."); + return Swizzle{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Swizzle, Swizzle) +{ + static_assert(S0 == S1, "Can only merge swizzles of the same shift."); + constexpr uint32_t Y = Swizzle::yyy_msk::value ^ Swizzle::yyy_msk::value; + constexpr uint32_t Z = Swizzle::zzz_msk::value ^ Swizzle::zzz_msk::value; + return make_swizzle(); + + //return ComposedFn, Swizzle>{}; +} + +// +// Utility for slicing and swizzle "offsets" +// + +// For swizzle functions, it is often needed to keep track of which bits are +// consumed and which bits are free. Furthermore, it is useful to know whether +// each of these bits is known statically or dynamically. + +// MixedBits is an 32-bit unsigned integer class where some bits are known statically +// and some bits are known dynamically. These sets of bits are disjoint and it is +// known statically which bits are known dynamically. 
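Before moving on to MixedBits, a small worked example of the Swizzle functor defined above may help; this is a sketch in plain integers, where (BBits,MBase,SShift) = (3,4,3) is simply a common parameterization, the masks are hard-coded from the bit diagram rather than taken from the class, and the negative-SShift case is ignored.

    #include <cstdint>

    // Swizzle<3,4,3>::apply modelled on a plain uint32_t: keep the low MBase=4 bits,
    // XOR the BBits=3 "YYY" bits (bits [9:7]) into the "ZZZ" bits (bits [6:4])
    // SShift=3 positions below them.
    constexpr uint32_t toy_swizzle_3_4_3(uint32_t offset) {
      constexpr uint32_t yyy_msk = 0x7u << (4 + 3);         // bits [9:7]
      return offset ^ ((offset & yyy_msk) >> 3);            // ZZZ ^= YYY
    }

    static_assert(toy_swizzle_3_4_3(0x1A0u) == 0x190u,
                  "YYY bits 0b011 flip ZZZ bits 0b010 to 0b001");
    static_assert(toy_swizzle_3_4_3(toy_swizzle_3_4_3(0x1A0u)) == 0x1A0u,
                  "the swizzle is its own inverse");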
+ +// MixedBits can only be manipulated through bitwise operations + +// Abstract value: StaticInt | (dynamic_int_ & StaticFlags) +template // 0: static, 1: dynamic +struct MixedBits +{ + // Representation invariants + static_assert(StaticFlags != 0, "Should be at least one dynamic bit in MixedBits."); + static_assert((StaticInt & StaticFlags) == 0, "No static/dynamic overlap allowed in MixedBits."); + + uint32_t dynamic_int_; + // assert((dynamic_int_ & ~StaticFlags) == 0); + + CUTE_HOST_DEVICE constexpr operator uint32_t() const noexcept { return StaticInt | dynamic_int_; } +}; + +// Return a value representing (C{} | (d & C)) potentially using MixedBits to track s and f. +// This maker does allow ((s & f) != 0) and enforces the MixedBits invariant before creation. +template +CUTE_HOST_DEVICE constexpr +auto +make_mixed_bits(C, DynamicType const& d, C) +{ + static_assert(is_integral::value); + constexpr uint32_t new_f = uint32_t(f) & ~uint32_t(s); // StaticBits take precedence, M<0,f>{d} | C{} + if constexpr (new_f == 0 || is_static::value) { + return C{} | (d & C{}); // Just return a static int + } else { + return MixedBits{uint32_t(d) & new_f}; // MixedBits + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Operators +// + +// Equality +template +CUTE_HOST_DEVICE constexpr +auto +operator==(MixedBits const& m, C) +{ + return (S0 == (uint32_t(S1) & ~F0)) && (m.dynamic_int_ == (uint32_t(S1) & F0)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator==(C s, MixedBits const& m) +{ + return m == s; +} + +// Bitwise AND +template +CUTE_HOST_DEVICE constexpr +auto +operator&(MixedBits const& m0, MixedBits const& m1) +{ + // Truth table for (S0,D0,F0) & (S1,D1,F1) -> (S,D,F) + // S0D0F0 | 0X0 | 001 | 011 | 1X0 | + // S1D1F1 + // 0X0 | 0X0 | 0X0 | 0X0 | 0X0 | + // 001 | 0X0 | 001 | 001 | 001 | + // 011 | 0X0 | 001 | 011 | 011 | + // 1X0 | 0X0 | 001 | 011 | 1X0 | + + return make_mixed_bits(C{}, + //(S0 | m0.dynamic_int_) & (S1 | m1.dynamic_int_), + ((S1 & F0) & m0.dynamic_int_) | ((S0 & F1) & m1.dynamic_int_) | (m0.dynamic_int_ & m1.dynamic_int_), + C<(S1 & F0) | (S0 & F1) | (F0 & F1)>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator&(MixedBits const& m, C) +{ + return make_mixed_bits(C{}, + m.dynamic_int_, + C{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator&(C s, MixedBits const& m) +{ + return m & s; +} + +// Bitwise OR +template +CUTE_HOST_DEVICE constexpr +auto +operator|(MixedBits const& m0, MixedBits const& m1) +{ + // Truth table for (S0,D0,F0) | (S1,D1,F1) -> (S,D,F) + // S0D0F0 | 0X0 | 001 | 011 | 1X0 | + // S1D1F1 + // 0X0 | 0X0 | 001 | 011 | 1X0 | + // 001 | 001 | 001 | 011 | 1X0 | + // 011 | 011 | 011 | 011 | 1X0 | + // 1X0 | 1X0 | 1X0 | 1X0 | 1X0 | + + return make_mixed_bits(C{}, + ((~S1 & F0) & m0.dynamic_int_) | ((~S0 & F1) & m1.dynamic_int_), + C<(~S0 & F1) | (~S1 & F0)>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator|(MixedBits const& m, C) +{ + return make_mixed_bits(C{}, + m.dynamic_int_, + C{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator|(C s, MixedBits const& m) +{ + return m | s; +} + +// Bitwise XOR +template +CUTE_HOST_DEVICE constexpr +auto +operator^(MixedBits const& m0, MixedBits const& m1) +{ + // Truth table for (S0,D0,F0) ^ (S1,D1,F1) -> (S,D,F) + // S0D0F0 | 0X0 | 001 | 011 | 1X0 | + // S1D1F1 + // 0X0 | 0X0 | 001 | 011 | 1X0 | + // 001 | 001 | 001 | 011 | 011 | + // 011 | 011 | 011 | 001 | 001 | + // 1X0 | 1X0 | 011 | 001 | 0X0 | + + return make_mixed_bits(C<(~S0 & S1 & ~F0) | (S0 & ~S1 & ~F1)>{}, + (S0 | 
m0.dynamic_int_) ^ (S1 | m1.dynamic_int_), + C{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator^(MixedBits const& m, C) +{ + return make_mixed_bits(C<(~S0 & uint32_t(S1) & ~F0) | (S0 & ~uint32_t(S1))>{}, + (S0 | m.dynamic_int_) ^ uint32_t(S1), + C{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator^(C s, MixedBits const& m) +{ + return m ^ s; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator<<(MixedBits const& m, C) +{ + return make_mixed_bits(C<(S0 << S1)>{}, + m.dynamic_int_ << S1, + C<(F0 << S1)>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator>>(MixedBits const& m, C) +{ + return make_mixed_bits(C<(S0 >> S1)>{}, + m.dynamic_int_ >> S1, + C<(F0 >> S1)>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +shiftl(MixedBits const& m, C s) +{ + if constexpr (S1 >= 0) { + return m << s; + } else { + return m >> -s; + } +} + +template +CUTE_HOST_DEVICE constexpr +auto +shiftr(MixedBits const& m, C s) +{ + if constexpr (S1 >= 0) { + return m >> s; + } else { + return m << -s; + } +} + +// +// upcast and downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +safe_div(MixedBits const& m, C s) +{ + static_assert(has_single_bit(uint32_t(S1)), "Only divide MixedBits by powers of two."); + return make_mixed_bits(safe_div(C{}, s), + safe_div(m.dynamic_int_, s), + safe_div(C{}, s)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(MixedBits const& m) +{ + static_assert(has_single_bit(N), "Only divide MixedBits by powers of two."); + return safe_div(m, C{}); +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +upcast(T const& m) +{ + return safe_div(m, C{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(MixedBits const& m) +{ + static_assert(has_single_bit(N), "Only scale MixedBits by powers of two."); + return make_mixed_bits(C{}, + m.dynamic_int_ * N, + C{}); +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +downcast(T const& m) +{ + return m * C{}; +} + +// +// Convert a Pow2Layout+Coord to a MixedBits +// + +template +CUTE_HOST_DEVICE constexpr +auto +to_mixed_bits(Shape const& shape, Stride const& stride, Coord const& coord) +{ + if constexpr (is_tuple::value && is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform_apply(shape, stride, coord, [](auto const& s, auto const& d, auto const& c) { return to_mixed_bits(s,d,c); }, + [](auto const&... 
a) { return (a ^ ...); }); + } else if constexpr (is_integral::value && is_integral::value && is_integral::value) { + static_assert(decltype(shape*stride)::value == 0 || has_single_bit(decltype(shape*stride)::value), "Requires pow2 shape*stride."); + return make_mixed_bits(Int<0>{}, coord * stride, (shape - Int<1>{}) * stride); + } else { + static_assert(is_integral::value && is_integral::value && is_integral::value, "Either Shape, Stride, and Coord must be all tuples, or they must be all integral (in the sense of cute::is_integral)."); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_mixed_bits(Layout const& layout, Coord const& coord) +{ + return to_mixed_bits(layout.shape(), layout.stride(), idx2crd(coord, layout.shape())); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(Swizzle const&) +{ + printf("Sw<%d,%d,%d>", B, M, S); +} + +template +CUTE_HOST_DEVICE void print(MixedBits const& m) +{ + printf("M_%u|(%u&%u)=%u", S, m.dynamic_int_, F, uint32_t(m)); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, Swizzle const&) +{ + return os << "Sw<" << B << "," << M << "," << S << ">"; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) +{ + return os << "M_" << S << "|(" << m.dynamic_int_ << "&" << F << ")=" << uint32_t(m); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/swizzle_layout.hpp b/server/punica_kernels/include/cutlass/cute/swizzle_layout.hpp new file mode 100644 index 00000000..f795562c --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/swizzle_layout.hpp @@ -0,0 +1,561 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +#include + +/* Specialized functionality for a ComposedLayout of the form + * InvolutionFn o Offset o LayoutB + * where the InvolutionFn is a Swizzle and is not linear (hence the need for the Offset). + * + * Because these are specializations for core functions of ComposedLayout, these Swizzle Layouts + * provide similar functionality to Layout including tiling, partitioning, + * coordinate-to-index mapping and layout manipulations, but are not considered "normal" layouts. + * For example, these provide shape() and size() functions, but do not provide stride() functions. + * + * Furthermore, each of these specializations uses Swizzle<>-specific knowledge in its implementation and + * attempts to decay itself to a normal-layout with dynamic or static strides when certain slicing conditions + * are met. This is possible by determining the subdomain of the Swizzle<> function that is identity and + * testing if LayoutB's codomain is contained within it. In general, MizedBits is used as the Offset to track + * statically-vs-dynamically known bits in the Offset to improve the decay to static or dynamic normal layouts. + */ + +namespace cute +{ + +// +// Constructors +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Swizzle const& sxor) +{ + return composition(sxor, Layout,Int<1>>{}); +} + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +transfer_swizzle(Layout const& old_layout, + Layout const& new_layout) +{ + // Our goal is to determine a new swizzle for the strides in new_layout for consistent vectorizations + + // This is accomplished by identifying + // S o L :=: S? 
o L* + // We identify the "active" portion of S by computing (P o L)(c*) where P is a projection generated by S + // Then that active identifier is transformed through the layouts: + // L*(L[(P o L)(c*)]) + // which is a new swizzle identifier for S?, the new swizzle + + // Projections of the swizzle layout for composition, P + auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), + make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); + + // Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L] + auto layout_only_zy = composition(swizzle_only_zy, old_layout); + // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) + auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); + + // Get the Z bit and the Y bits -- keep only those that are active in Z *and* Y + auto zzz_msk = typename Swizzle::zzz_msk{}; + auto yyy_msk = typename Swizzle::yyy_msk{}; + auto msk_sft = typename Swizzle::msk_sft{}; + auto active_Z = swizzle_active_bits & shiftr(swizzle_active_bits, msk_sft) & zzz_msk; + auto active_Y = swizzle_active_bits & shiftr(swizzle_active_bits, -msk_sft) & yyy_msk; + + // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) + auto new_active_Z = new_layout(old_layout.get_1d_coord(active_Z)); + auto new_active_Y = new_layout(old_layout.get_1d_coord(active_Y)); + + // Use this new swizzle identifier to construct the new swizzle for new_layout + // (this also makes sure it's a "valid" swizzle that Swizzle can represent) + return composition(make_swizzle(), new_layout); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(ComposedLayout,Offset,Layout> const& layout) +{ + return detail::transfer_swizzle(layout.layout_b(), make_fragment_like(layout.layout_b())); +} + +// +// Utilities +// + +namespace detail { + +// Get just the Swizzle part of a composed layout. +template +CUTE_HOST_DEVICE constexpr +auto +get_swizzle_portion(ComposedLayout,Offset,LayoutB>) +{ + return Swizzle{}; +} + +// A non-swizzled layout's "Swizzle part" is the identity swizzle. +template +CUTE_HOST_DEVICE constexpr +auto +get_swizzle_portion(Layout) +{ + return Swizzle<0,4,3>{}; +} + +// Get the "non-swizzle" part of a composed layout, +// which is the underlying (non-composed) Layout. +template +CUTE_HOST_DEVICE constexpr +auto +get_nonswizzle_portion(ComposedLayout,Offset,LayoutB> const& slayout) +{ + return slayout.layout_b(); +} + +// The non-swizzle part of a non-swizzled layout is just the Layout. +template +CUTE_HOST_DEVICE constexpr +auto +get_nonswizzle_portion(Layout const& slayout) +{ + return slayout; +} + +} // namespace detail + +// +// Slice a Swizzled ComposedLayout +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +make_swizzle_strides(true_type, + IntZ const& Z, + IntY const& Y, + Offset const& offset, + int_sequence) +{ + // Below is an optimized/compressed version of: + //return cute::make_tuple((swizzle(offset + Z*Int<(1 << I)>{}) - swizzle(offset))...); + // with knowledge of Swizzle, I... 
ranges for each B bits, + // and the layout won't slice along z-bits that are already set + + // y\z 0 1 + // 0 Z DC + // 1 -Z DC + + return cute::make_tuple(conditional_return((offset & (Y << Int{})) == Int<0>{}, Z << Int{}, -(Z << Int{}))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_swizzle_strides(false_type, + IntZ const& Z, + IntY const& Y, + Offset const& offset, + int_sequence) +{ + // Below is an optimized/compressed version of: + //return cute::make_tuple((swizzle(offset + Y*Int<(1 << I)>{}) - swizzle(offset))...); + // with knowledge of Swizzle, I... ranges for each B bits, + // and the layout won't slice along y-bits that are already set + + // y\z 0 1 + // 0 Y+Z Y-Z + // 1 DC DC + + return cute::make_tuple(conditional_return((offset & (Z << Int{})) == Int<0>{}, (Y+Z) << Int{}, (Y-Z) << Int{})...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout> const& layout) +{ + if constexpr (all_underscore::value) { + // Skip the expensive/complicated attempt to decay to a normal layout and just reshape + return cute::make_tuple(composition(layout.layout_a(), layout.offset(), slice(coord, layout.layout_b())), Int<0>{}); + } else { + + // Projections of the swizzle layout for composition + auto sw = make_layout(make_shape(Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B)>{}, Int<1>{})); + + auto swizzle_anti_zy = make_layout(shape(sw), + make_stride(stride<0>(sw), Int<0>{}, stride<2>(sw), Int<0>{}, size(sw))); + auto swizzle_only_zy = make_layout(shape(sw), + make_stride( Int<0>{}, stride<1>(sw), Int<0>{}, stride<3>(sw), Int<0>{})); + + // The portion of the layout that is not yet consumed + auto sliced_layout = slice(coord, layout.layout_b()); + + // If the sliced_layout hits two bits that are swizzled together, then don't attempt to decay + + // Compose with the layout to get the swizzle projection, P o L [The Z and Y contributing portions of L] + // (this also tests that shape/stride of layout compose with swizzle) + auto sliced_layout_only_zy = composition(swizzle_only_zy, sliced_layout); + // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) + auto swizzle_active_bits = sliced_layout_only_zy(size(sliced_layout_only_zy)-Int<1>{}); + // Determine if any active bits collide under the swizzle + auto hit_ZandY = !(swizzle_active_bits & ~layout.layout_a()(swizzle_active_bits)); + + // The portion of the layout that we are consuming now + auto diced_layout = dice(coord, layout.layout_b()); + auto diced_coord = dice(coord, coord); + + auto diced_layout_anti_zy = composition(swizzle_anti_zy, diced_layout); + auto diced_layout_only_zy = composition(swizzle_only_zy, diced_layout); + + // New swizzle and offset + auto swizzle = layout.layout_a(); + // offset_only_zy interacts with swizzle and gets accumulated with layout.offset() + // being careful about the static/dynamic contributions from diced_layout and diced_coord + auto offset_only_zy = layout.offset() ^ to_mixed_bits(diced_layout_only_zy, diced_coord); + // offset_anti_zy always gets passed through, no interaction with swizzle + auto offset_anti_zy = diced_layout_anti_zy(diced_coord); + + // If Layout's codomain hits on Y AND Z, then it's not reducible + // If Layout's codomain hits on Y XOR Z, then it's dynamic-normal + // If Layout's codomain hits on neither Y NOR Z, then it's static-normal + + // Test the sliced layout for hit_X & hit_Y for potential decay + if constexpr 
(is_constant::value) + { // Hits on Y AND Z, so it's not reducible + return cute::make_tuple(composition(swizzle, offset_only_zy, sliced_layout), offset_anti_zy); + } else + { // Misses on Y or Z, so it's static-normal or dynamic-normal + + // Lowest bit of the Z and Y masks + auto Z = typename Swizzle::zzz_msk{} & -typename Swizzle::zzz_msk{}; + auto Y = typename Swizzle::yyy_msk{} & -typename Swizzle::yyy_msk{}; + auto stride_lo = detail::make_swizzle_strides(Z < Y, Z, Y, offset_only_zy, make_int_sequence{}); + auto stride_hi = detail::make_swizzle_strides(Z > Y, Z, Y, offset_only_zy, make_int_sequence{}); + + // Construct a (dynamic) layout that we can perform the composition with + auto swizzle_layout = make_layout(make_shape (Int<(1 << M)>{}, repeat(Int<2>{}), Int<(1 << (abs(S)-B))>{}, repeat(Int<2>{}), Int< 1>{}), + make_stride(Int< 1>{}, stride_lo, Int<(1 << (M+B))>{}, stride_hi , Int<(1 << (M+B+abs(S)))>{})); + + // Decay to a normal layout with offset + return cute::make_tuple(composition(swizzle_layout, sliced_layout), + swizzle(offset_only_zy) + offset_anti_zy); + } + } + + CUTE_GCC_UNREACHABLE; +} + +// +// composition +// + +// Ignore identity case +template +CUTE_HOST_DEVICE constexpr +auto +composition(Swizzle<0,M,S> const&, + Int<0> const&, + Layout const& layout) +{ + return layout; +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Swizzle const& sxor, + Layout const& layout) +{ + return composition(sxor, Int<0>{}, layout); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& a, + Swizzle const& b) +{ + // Get the Z bits and the Y bits + auto active_Y = a(typename Swizzle::yyy_msk{}); + auto active_Z = a(typename Swizzle::zzz_msk{}); + + // Works in simple cases... but could be greatly generalized + + return composition(make_swizzle(), a); +} + +// +// inverse +// + +// Specialization to attempt to pass-through the Swizzle back to the left -- Needed? +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(ComposedLayout,Offset,Layout> const& layout) +{ + if constexpr (is_constant<0, Offset>::value) { + return composition(right_inverse(layout.layout_b()), layout.layout_a()); + } else { + return composition(right_inverse(layout.layout_b()), right_inverse(layout.offset()), right_inverse(layout.layout_a())); + } +} + +// Specialization to attempt to pass-through the Swizzle back to the left -- Needed? 
+template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(ComposedLayout,Offset,Layout> const& layout) +{ + if constexpr (is_constant<0, Offset>::value) { + return composition(left_inverse(layout.layout_b()), layout.layout_a()); + } else { + return composition(left_inverse(layout.layout_b()), left_inverse(layout.offset()), left_inverse(layout.layout_a())); + } +} + +template +CUTE_HOST_DEVICE constexpr +Swizzle +right_inverse(Swizzle const& sw) +{ + return sw; +} + +template +CUTE_HOST_DEVICE constexpr +Swizzle +left_inverse(Swizzle const& sw) +{ + return sw; +} + +// Kludge -- Probably want an OffsetFn here instead +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +right_inverse(T const& t) +{ + return -t; +} + +// Kludge -- Probably want an OffsetFn here instead +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +left_inverse(T const& t) +{ + return -t; +} + +// +// Upcast and Downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Swizzle const& swizzle) +{ + static_assert(has_single_bit(N), "N must be a power of two"); + constexpr int log2_n = bit_width(uint32_t(N)) - 1; + constexpr int NewM = M - log2_n; + if constexpr (NewM >= 0) { + return Swizzle{}; + } else { + return Swizzle{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Swizzle const& swizzle) +{ + static_assert(has_single_bit(N), "N must be a power of two"); + constexpr int log2_n = bit_width(uint32_t(N)) - 1; + return Swizzle{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast_layout(Swizzle const& swizzle) +{ + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { + return swizzle; + } + else if constexpr (scale::num == 1) { + return downcast(swizzle); + } + else if constexpr (scale::den == 1) { + return upcast(swizzle); + } + else { + static_assert(dependent_false, "Recast not supported."); + } + CUTE_GCC_UNREACHABLE; +} + +// +// Other operations +// + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_layout(ComposedLayout,Offset,LayoutB> const& a, + Layout const& b) +{ + auto common = max_common_layout(a.layout_b(), b); + auto base = Int<(1 << M)>{}; + if constexpr (base < size(common)) { + return common.compose(base); // Truncate common to size base + } else { + return common; + } +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_layout(Layout const& a, + ComposedLayout,Offset,LayoutB> const& b) +{ + return max_common_layout(b, a); +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(ComposedLayout,Offset,LayoutB> const& a, + Layout const& b) +{ + // This assumes that Offset is in the YZ domain of the Swizzle... + return cute::min(Int<(1 << M)>{}, max_common_vector(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Layout const& a, + ComposedLayout,Offset,LayoutB> const& b) +{ + return max_common_vector(b, a); +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(ComposedLayout,Offset0,LayoutB0> const& a, + ComposedLayout,Offset1,LayoutB1> const& b) +{ + auto result = coalesce(composition(a, right_inverse(b))); + + if constexpr (is_constant<1, decltype(stride<0>(result.layout_b()))>::value) { + return shape<0>(result); + } else { + return Int<1>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +/////////////////////////////////////////////////////////////////////////////// +// ComposedLayout as second argument is often more difficult... 
+ +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& layout, + ComposedLayout,Offset,LayoutT> const& tiler) +{ + CUTE_STATIC_ASSERT_V(tiler.offset() == Int<0>{}, "Require Swizzle offset == 0."); + // The new layout -- if swizzle wasn't an issue, this is the result + // our goal is to determine a new swizzle for these strides + auto new_layout = logical_product(layout, tiler.layout_b()); + + // This is accomplished by identifying + // S o L :=: S? o L* + // We identify the "active" portion of S by computing (P o L)(c*) where P is a projection generated by S + // Then that active identifier is transformed through the layouts: + // L*(L[(P o L)(c*)]) + // which is a new swizzle identifier for S?, the new swizzle + + // Projections of the swizzle layout for composition, P + auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), + make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); + + // Compose with the tiler to get the swizzle projection, P o L [The Z and Y contributing portions of L] + auto layout_only_zy = composition(swizzle_only_zy, tiler.layout_b()); + // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) + auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); + // Get the Z bit and the Y bits + auto active_Z = swizzle_active_bits & typename Swizzle::zzz_msk{}; + auto active_Y = swizzle_active_bits & typename Swizzle::yyy_msk{}; + + // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) + auto new_active_Z = new_layout(Int<0>{}, tiler.layout_b()[active_Z]); + auto new_active_Y = new_layout(Int<0>{}, tiler.layout_b()[active_Y]); + + // Use this new swizzle identifier to construxt the new swizzle for new_layout + // (this also makes sure it's a "valid" swizzle that Swizzle can represent) + return composition(make_swizzle(), new_layout); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/tensor.hpp b/server/punica_kernels/include/cutlass/cute/tensor.hpp new file mode 100644 index 00000000..28d3ee67 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/tensor.hpp @@ -0,0 +1,1099 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +#include +#include +#include + +#include +#include + +namespace cute +{ + +// +// Engine -- owning or non-owning data store +// + +// concept Engine { +// using iterator = ; +// using value_type = ; +// using element_type = ; +// using reference = ; +// iterator begin(); +// }; + +template +struct ArrayEngine +{ + using Storage = typename conditional<(sizeof_bits::value % 8 == 0), + array_aligned, + array_subbyte>::type; + using iterator = typename Storage::iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + Storage storage_; + + CUTE_HOST_DEVICE constexpr auto begin() const { return storage_.begin(); } + CUTE_HOST_DEVICE constexpr auto begin() { return storage_.begin(); } +}; + +template +struct ViewEngine +{ + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + iterator storage_; + + CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; } + CUTE_HOST_DEVICE constexpr iterator & begin() { return storage_; } +}; + +template +struct ConstViewEngine +{ + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + iterator storage_; + + CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; } +}; + +// +// Tensor +// + +template +struct Tensor +{ + using iterator = typename Engine::iterator; + using value_type = typename Engine::value_type; + using element_type = typename Engine::element_type; + using reference = typename Engine::reference; + + using engine_type = Engine; + using layout_type = Layout; + + CUTE_HOST_DEVICE constexpr + Tensor() {} + + template + CUTE_HOST_DEVICE constexpr + Tensor(Ptr const& ptr, Layout const& layout) + : rep_(layout, ptr) { + } + + // + // Accessors + // + + static constexpr int rank = Layout::rank; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + tensor() const { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return get<0>(rep_); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + engine() const { + return get<1>(rep_); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + engine() { + return get<1>(rep_); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + data() const { + return engine().begin(); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + data() { + return engine().begin(); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return layout().shape(); + } + + CUTE_HOST_DEVICE constexpr + auto + size() const { + 
return cute::size(shape()); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const { + return layout().stride(); + } + + // + // Indexing op() and op[] + // + + // Index into this tensor like an array by computing the offset via layout() + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator[](Coord const& coord) { + return data()[layout()(coord)]; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator[](Coord const& coord) const { + return data()[layout()(coord)]; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord const& coord) { + if constexpr (has_underscore::value) { + auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); + return make_tensor(data() + offset, sliced_layout); + } else { + return data()[layout()(coord)]; + } + + CUTE_GCC_UNREACHABLE; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); + return make_tensor(data() + offset, sliced_layout); + } else { + return data()[layout()(coord)]; + } + + CUTE_GCC_UNREACHABLE; + } + + // op() convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) { + return operator()(make_coord(c0,c1,cs...)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) { + return make_tensor(data(), layout().compose(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return make_tensor(data(), layout().compose(layouts...)); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) { + return make_tensor(data(), layout().tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... 
layouts) const { + return make_tensor(data(), layout().tile(layouts...)); + } + + // + // Utility + // + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_1d_coord(Int const& linear_idx) const { + return layout().get_1d_coord(linear_idx); + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_hier_coord(Int const& linear_idx) const { + return layout().get_hier_coord(linear_idx); + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_flat_coord(Int const& linear_idx) const { + return layout().get_flat_coord(linear_idx); + } + + cute::tuple rep_; +}; + +template +struct is_tensor : false_type {}; +template +struct is_tensor> : true_type {}; + +// Customization point for creation of owning and non-owning Tensors +template +struct MakeTensor +{ + template ::value && + is_layout::value)> + CUTE_HOST_DEVICE constexpr auto + operator()(Layout const& layout) const + { + static_assert(is_static::value, "Dynamic owning tensors not supported"); + using Engine = ArrayEngine>; + return Tensor(); + } + + template ::value && + is_layout::value)> + CUTE_HOST_DEVICE constexpr auto + operator()(T const& iter, Layout const& layout) + { + using Engine = ViewEngine; + return Tensor(iter, layout); + } + + template ::value)> + CUTE_HOST_DEVICE constexpr auto + operator()(LayoutArg const& arg, LayoutArgs const&... args) const + { + return operator()(make_layout(arg, args...)); + } + + template ::value)> + CUTE_HOST_DEVICE constexpr auto + operator()(T const& iter, LayoutArg const& arg, LayoutArgs const&... args) + { + return operator()(iter, make_layout(arg, args...)); + } +}; + +// +// make_tensor +// + +// Make an owning Tensor that will allocate a static array +// e.g. make_tensor(Int<12>{}) +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Args const&... args) +{ + return MakeTensor{}(args...); +} + +// Make a non-owning Tensor that will use a pointer (view) +// e.g. make_tensor(vec.data(), 12) +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Iterator const& iter, Args const&... args) +{ + return MakeTensor{}(iter, args...); +} + +// +// make_tensor_like +// Make a register tensor the same type and shape and (if possible) order as another tensor +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor_like(Layout const& layout) +{ + return make_tensor(make_layout_like(layout)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor_like(Tensor const& tensor) +{ + return make_tensor_like(tensor.layout()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor_like(Tensor const& tensor) +{ + return make_tensor_like(tensor.layout()); +} + +// +// make_fragment_like -- +// Make a tensor the same shape and (if possible) order as another tensor, with special +// consideration of the 0th mode. The 0th mode is commonly used for MMA_Atoms or Copy_Atoms +// so this allocates the 0th mode with LayoutLeft regardless of the reference layout. +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Layout const& layout) +{ + return make_tensor(make_fragment_like(layout)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Tensor const& tensor) +{ + return make_fragment_like(tensor.layout()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Tensor const& tensor) +{ + return make_fragment_like(tensor.layout()); +} + +// +// make_counting_tensor +// Make a tensor from a layout by binding it to a counting iter with 0-offset of the same profile as the codomain. 
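
Before moving on to the counting and identity tensors, here is a hedged usage sketch of the make_tensor and make_fragment_like factories defined above. The include path, the make_gmem_ptr wrapper, and the pointer/extent names are illustrative assumptions, not part of this patch.

#include <cute/tensor.hpp>   // assumed include path for the header above

// m and k are runtime extents; ptr is a hypothetical global-memory pointer.
__device__ void factories_sketch(float const* ptr, int m, int k) {
  using namespace cute;

  // Owning tensor: allocates a static register/stack array; the layout must be static.
  auto rC   = make_tensor<float>(Shape<_4,_8>{});

  // Non-owning view over existing memory: dynamic extents are allowed.
  auto gA   = make_tensor(make_gmem_ptr(ptr),
                          make_shape(m, k), make_stride(k, Int<1>{}));   // row-major view

  // Register tensor with the same value_type and shape as rC,
  // with mode-0 laid out LayoutLeft as described for make_fragment_like above.
  auto frag = make_fragment_like(rC);
}
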
+// + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +make_counting_tensor(Layout const& layout) +{ + return make_tensor(make_inttuple_iter(repeat_like(coshape(layout), Int<0>{})), layout); +} + +// +// make_identity_tensor +// Make a tensor that maps coordinates within a shape to themselves. +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_identity_tensor(Shape const& shape) +{ + return make_counting_tensor(make_identity_layout(shape)); +} + +// +// Utilities +// + +// Return the subtensor of a mode +template >::value)> +CUTE_HOST_DEVICE constexpr +decltype(auto) +tensor(Tensor&& tensor) +{ + return static_cast(tensor); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +decltype(auto) +tensor(Tensor&& tensor) +{ + return make_tensor(static_cast(tensor).data(), get(tensor.layout())); +} + +// Return the layout of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(Tensor const& tensor) +{ + return layout(tensor.layout()); +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(Tensor const& tensor) +{ + return shape(tensor.layout()); +} + +// Return the stride of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(Tensor const& tensor) +{ + return stride(tensor.layout()); +} + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +size(Tensor const& tensor) +{ + return size(tensor.layout()); +} + +// Return the rank of a mode +template +CUTE_HOST_DEVICE constexpr +auto +rank(Tensor const& tensor) +{ + return rank(tensor.layout()); +} + +// Return the depth of a mode +template +CUTE_HOST_DEVICE constexpr +auto +depth(Tensor const& tensor) +{ + return depth(tensor.layout()); +} + +// +// Operations to manipulate Tensors like a Layout +// + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +flatten(Tensor&& tensor) +{ + return make_tensor(static_cast(tensor).data(), flatten(tensor.layout())); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +coalesce(Tensor&& tensor) +{ + return make_tensor(static_cast(tensor).data(), coalesce(tensor.layout())); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +coalesce(Tensor&& tensor, Profile const& profile) +{ + return make_tensor(static_cast(tensor).data(), coalesce(tensor.layout(), profile)); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor&& tensor) +{ + return make_tensor(static_cast(tensor).data(), filter_zeros(tensor.layout())); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +filter(Tensor&& tensor) +{ + return make_tensor(static_cast(tensor).data(), filter(tensor.layout())); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +filter(Tensor&& tensor, Profile const& profile) +{ + return make_tensor(static_cast(tensor).data(), filter(tensor.layout(), profile)); +} + +// Return a tensor with the same shape as input but offset by a given coordinate +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +domain_offset(Coord const& coord, Tensor&& tensor) +{ + auto [layout, ptr_offset] = domain_offset(coord, tensor.layout()); + return make_tensor(static_cast(tensor).data() + ptr_offset, layout); +} + +// Group the modes [B,E) into a single mode +// e.g. 
group<2,4>(make_tensor(Layout>{})) +// => make_tensor(Layout,_5,_6>>{}) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +group_modes(Tensor&& tensor) +{ + return make_tensor(static_cast(tensor).data(), + group(tensor.layout())); +} + +// Return the subtensor of a range of modes +template >::value)> +CUTE_HOST_DEVICE constexpr +decltype(auto) +take(Tensor&& tensor) +{ + return make_tensor(static_cast(tensor).data(), take(tensor.layout())); +} + +// +// Recast +// + +// NOTE: This is very dangerous to do +// -- doesn't check dynamic integer divisibility +// -- doesn't check alignment + +template +CUTE_HOST_DEVICE constexpr +auto +recast(Tensor&& tensor) +{ + using OldType = typename remove_cvref_t::value_type; + auto old_layout = tensor.layout(); + auto new_layout = recast_layout(old_layout); + + // If this is an upcast of a normal Layout with static negative strides, then offset as well + if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { + auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{}); + auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{}); + auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); }); + + return make_tensor(recast_ptr(static_cast(tensor).data() + offset), new_layout); + } else { + return make_tensor(recast_ptr(static_cast(tensor).data() ), new_layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// max_common_vector +// + +/* Return Int such that N is the maximum number of contiguous elements + * that logically correspond in the tensors of @a a and @a b. This is, + * the number of elements that could reasonably be vectorized into a single load/store. + * + * @returns Int with N >= 0 + * + * A return value of Int<0> indicates that no such conclusion can be made and no + * vectorization should be attempted. + * + * Note that the return value does NOT include alignment concerns such as the pointer value and + * the divisbility of dynamic strides. + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Tensor const& a, + Tensor const& b) +{ + using SrcType = typename Tensor::value_type; + using DstType = typename Tensor::value_type; + using SrcRef = typename Tensor::reference; + using DstRef = typename Tensor::reference; + + // Determine if vectorization candidates at all + if constexpr (// Should be the same value_types, else the copy is also performing a cast + sizeof_bits_v == sizeof_bits_v && + // The types should be trivially copyable so that vectorization is valid + is_trivially_copyable::value && + is_trivially_copyable::value && + // Should be load/storing real data, rather than implicit iterators or such + is_reference::value && + is_reference::value) + { + return max_common_vector(a.layout(), b.layout()); + } else { + return Int<0>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +/* Return a layout that points to the maximum number of contiguous elements + * that logically correspond in the tensors of @a a and @a b. This is, + * the elements that could reasonably be "vectorized" into a single load/store. + * + * @returns Layout R such that composition(a.layout(), R) and composition(b.layout(), R) + * are both identity Layouts. + * + * Note that the returned layout does NOT include alignment concerns such as the pointer value and + * the divisbility of dynamic strides. 
+ */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_layout(Tensor const& a, + Tensor const& b) +{ + using SrcType = typename Tensor::value_type; + using DstType = typename Tensor::value_type; + using SrcRef = typename Tensor::reference; + using DstRef = typename Tensor::reference; + + // Determine if vectorization candidates at all + if constexpr (// Should be the same value_types, else the copy is also performing a cast + sizeof_bits_v == sizeof_bits_v && + // The types should be trivially copyable so that vectorization is valid + is_trivially_copyable::value && + is_trivially_copyable::value && + // Should be load/storing real data, rather than implicit iterators or such + is_reference::value && + is_reference::value) + { + return max_common_layout(a.layout(), b.layout()); + } else { + return Layout<_1,_0>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Key algebraic operations -- Divide and Product +// + +// Apply a Tiler to the Tensor. +// +// Consider a Tensor with shape (A,B,x,y) +// And a Tiler that is: +// +// * A Layout with shape (BLK_A,BLK_B) +// ** Result Tensor shape ((BLK_A,BLK_B),Rest). +// ** That is, the Tensor and Tile are treated as 1D for the tiling. +// ** See logical_divide(Layout,Layout) +// +// * A Tile with shape +// ** Result Tensor shape ((BLK_A,a),(BLK_B,b),x,y). +// ** Each mode of the Tile is applied to the corresponding mode of the Tensor. +// ** See logical_divide(Layout,Tuple) +// +// * A Shape (BLK_A,BLK_B) +// ** Result Tensor shape ((BLK_A,a),(BLK_B,b),x,y). +// ** Equivalent to applying Tile. +// ** See logical_divide(Layout,Tuple) and logical_divide(Layout,Int) +// +// Note that the Tile/Shape Tilers must be weakly_congruent to the Tensor +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + logical_divide(tensor.layout(), tiler)); +} + +// zipped_divide is logical_divide with Tiler modes and Rest modes gathered together: (Tiler,Rest) +// When Tiler is Layout, this has no effect as logical_divide results in the same. +// When Tiler is Tile or Shape, this zips modes into standard form ((BLK_A,BLK_B),(a,b,x,y)) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + zipped_divide(tensor.layout(), tiler)); +} + +// tiled_divide is zipped_divide with the second output mode flattened ((BLK_A,BLK_B),a,b,x,y) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + tiled_divide(tensor.layout(), tiler)); +} + +// flat_divide is zipped_divide with the both modes flattened (BLK_A,BLK_B,a,b,x,y) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +flat_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + flat_divide(tensor.layout(), tiler)); +} + +// logical_product on a Tensor doesn't make sense since it often increases cosize +// though this might make sense for creating Tensors with broadcasted (stride-0) modes + +// +// Tensor partitioning utilities +// + +// Apply a Tiler to the Tensor, then slice out one of those tiles by slicing into the "Rest" modes. +// With an inner_partition, you get everything that's inside the Tiler. Everything that the Tiler is pointing to. 
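
As a rough sketch of the shapes involved (assuming a statically sized 256x128 row-major view whose extents divide the tile evenly, and a hypothetical global-memory pointer), applying a Shape tiler with zipped_divide and then slicing the "Rest" mode recovers a single tile, which is exactly what the partitioning helpers below automate:

__device__ void tiling_sketch(float const* ptr) {
  using namespace cute;

  auto gA = make_tensor(make_gmem_ptr(ptr),
                        make_shape(Int<256>{}, Int<128>{}),
                        make_stride(Int<128>{}, Int<1>{}));      // (256,128) row-major view

  auto tiler = Shape<_32,_64>{};                                 // (BLK_M, BLK_N)

  auto tiled = zipped_divide(gA, tiler);                         // ((_32,_64), (_8,_2))
  auto blk   = tiled(make_coord(_,_), make_coord(3,1));          // (_32,_64): the (3,1)-th tile
}
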
+// Split the modes of tensor according to the Tiler +// zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// Then slice into the second mode (the "Rest" mode) with Coord +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +inner_partition(Tensor && tensor, + Tiler const& tiler, + Coord const& coord) +{ + auto tensor_tiled = zipped_divide(static_cast(tensor), tiler); + constexpr int R0 = decltype(rank<0>(tensor_tiled))::value; + + // The coord slices into the second mode (the "rest" mode), flatten the first + if constexpr (is_tuple::value) { + // Append trailing modes if coord is tuple + constexpr int R1 = decltype(rank<1>(tensor_tiled))::value;; + return tensor_tiled(repeat(_), append(coord,_)); + } else { + // Flat indexing if coord is not tuple + return tensor_tiled(repeat(_), coord); + } +} + +// Apply a Tiler to the Tensor, then slice out the remainder by slicing into the "Tile" modes. +// With an outer_partition, you get everything that's outside the Tiler. The layout of the Tile in the Tensor. +// Split the modes of tensor according to the Tiler +// zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// Then slice into the first mode (the "Tile" mode) with Coord +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +outer_partition(Tensor && tensor, + Tiler const& tiler, + Coord const& coord) +{ + auto tensor_tiled = zipped_divide(static_cast(tensor), tiler); + constexpr int R1 = decltype(rank<1>(tensor_tiled))::value; + + // The coord slices into the first mode (the "tile" mode), flatten the second + if constexpr (is_tuple::value) { + // Append trailing modes if coord is tuple + constexpr int R0 = decltype(rank<0>(tensor_tiled))::value; + return tensor_tiled(append(coord,_), repeat(_)); + } else { + // Flat indexing if coord is not tuple + return tensor_tiled(coord, repeat(_)); + } +} + +// Tile a tensor according to @a tiler and use @a coord to index into the remainder, keeping the tile. +// This is typical at the CTA level where tiles of data are extracted: +// Tensor data = ... // ( M, N) +// Tensor cta_data = local_tile(data, Shape<_32,_64>{}, make_coord(blockIdx.x,blockIdx.y)); // (_32,_64) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +local_tile(Tensor && tensor, + Tiler const& tiler, // tiler to apply + Coord const& coord) // coord to slice into "remainder" +{ + return inner_partition(static_cast(tensor), + tiler, + coord); +} + +// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience +// when using projections of the same tiler. +// This is typical at the CTA level where tiles of data are extracted as projections: +// Tensor dataA = ... // (M,K) +// Tensor dataB = ... // (N,K) +// Tensor dataC = ... 
// (M,N) +// auto cta_tiler = Shape<_32, _64, _4>{}; +// auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); +// Tensor ctaA = local_tile(dataA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (_32,_4,k) +// Tensor ctaB = local_tile(dataA, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (_64,_4,k) +// Tensor ctaC = local_tile(dataA, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (_32,_64) +template >::value)> +CUTE_HOST_DEVICE +auto +local_tile(Tensor && tensor, + Tiler const& tiler, // tiler to apply + Coord const& coord, // coord to slice into "remainder" + Proj const& proj) // projection to apply to tiler and coord +{ + return local_tile(static_cast(tensor), + dice(proj, tiler), + dice(proj, coord)); +} + +// Tile a tensor according to the flat shape of a layout that provides the coordinate of the target index. +// This is typical at the Thread level where data is partitioned across repeated patterns of threads: +// Tensor data = ... // (_16,_64) +// Tensor thr_data = local_partition(data, Layout>{}, thr_idx); // ( _8, _4) +template >::value)> +CUTE_HOST_DEVICE +auto +local_partition(Tensor && tensor, + Layout const& tile, // coord -> index + Index const& index) // index to slice for +{ + static_assert(is_integral::value); + return outer_partition(static_cast(tensor), + product_each(shape(tile)), + tile.get_flat_coord(index)); +} + +// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience +// when using projections of the same tiler. +// This is typical at the Thread level where data is partitioned across projected layouts of threads: +// Tensor dataA = ... // (M,K) +// Tensor dataB = ... // (N,K) +// Tensor dataC = ... // (M,N) +// auto thr_layout = Layout, Stride<_16,_1,_0>>{}; +// Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{}); // (M/2,K/1) +// Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{}); // (N/16,K/1) +// Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{}); // (M/2,N/16) +template >::value)> +CUTE_HOST_DEVICE +auto +local_partition(Tensor && tensor, + Layout const& tile, // coord -> index + Index const& index, // index to slice for + Projection const& proj) +{ + return local_partition(static_cast(tensor), + dice(proj, tile), + index); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(Tensor const& tensor) +{ + print(tensor.data()); print(" o "); print(tensor.layout()); +} + +template +CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor, bool print_type = true) +{ + if (print_type) { + print(tensor); print(":\n"); + } + + if constexpr (Layout::rank == 1) + { + for (int m = 0; m < size(tensor); ++m) { + pretty_print(tensor(m)); + printf("\n"); + } + } else + if constexpr (Layout::rank == 2) + { + for (int m = 0; m < size<0>(tensor); ++m) { + for (int n = 0; n < size<1>(tensor); ++n) { + pretty_print(tensor(m,n)); + } + printf("\n"); + } + } else + if constexpr (Layout::rank == 3) + { + print_tensor(tensor(_,_,0), false); + for (int k = 1; k < size<2>(tensor); ++k) { + for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n"); + print_tensor(tensor(_,_,k), false); + } + } else + if constexpr (Layout::rank == 4) + { + print_tensor(tensor(_,_,_,0), false); + for (int p = 1; p < size<3>(tensor); ++p) { + for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n"); + print_tensor(tensor(_,_,_,p), false); + } + } +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, 
Tensor const& tensor) +{ + int digits = 9; + + if constexpr (Layout::rank == 1) + { + for (int m = 0; m < size(tensor); ++m) { + os << std::setw(digits) << tensor(m) << std::endl; + } + } else + if constexpr (Layout::rank == 2) + { + for (int m = 0; m < size<0>(tensor); ++m) { + for (int n = 0; n < size<1>(tensor); ++n) { + os << std::setw(digits) << tensor(m,n); + } + os << std::endl; + } + } else + if constexpr (Layout::rank == 3) + { + print_tensor_os(os, tensor(_,_,0)); + for (int k = 1; k < size<2>(tensor); ++k) { + for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl; + print_tensor_os(os, tensor(_,_,k)); + } + } else + if constexpr (Layout::rank == 4) + { + print_tensor_os(os, tensor(_,_,_,0)); + for (int p = 1; p < size<3>(tensor); ++p) { + for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl; + print_tensor_os(os, tensor(_,_,_,p)); + } + } + + return os; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const& tensor) +{ + os << tensor.layout() << std::endl; + return print_tensor_os(os, tensor); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute + +// +// Extended Engines +// + +#include +#include + +// +// Tensor Algorithms +// + +#include +#include +#include +#include +#include +#include +#include + +#include +#include diff --git a/server/punica_kernels/include/cutlass/cute/tensor_predicate.hpp b/server/punica_kernels/include/cutlass/cute/tensor_predicate.hpp new file mode 100644 index 00000000..68146470 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/tensor_predicate.hpp @@ -0,0 +1,79 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +template +struct ConstantTensor +{ + template + CUTE_HOST_DEVICE constexpr + T const& + operator()(Coords const&...) const { + return val_; + } + + T val_; +}; + +struct TrivialPredTensor +{ + template + CUTE_HOST_DEVICE constexpr + true_type + operator()(Coords const&...) const { + return {}; + } +}; + +template +struct FunctionPredTensor +{ + CUTE_HOST_DEVICE constexpr + FunctionPredTensor(Fn const& fn) : fn_(fn) {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coords const&... coords) const { + return fn_(coords...); + } + + Fn const& fn_; +}; + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/underscore.hpp b/server/punica_kernels/include/cutlass/cute/underscore.hpp new file mode 100644 index 00000000..212f42d7 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/underscore.hpp @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include + +namespace cute +{ + +// For slicing +struct Underscore : Int<0> {}; + +CUTE_INLINE_CONSTANT Underscore _; + +// Convenient alias +using X = Underscore; + +// Treat Underscore as an integral like integral_constant +template <> +struct is_integral : true_type {}; + +template +struct is_underscore : false_type {}; +template <> +struct is_underscore : true_type {}; + +// Tuple trait for detecting static member element +template +struct has_elem : false_type {}; +template +struct has_elem : true_type {}; +template +struct has_elem::value> > + : has_elem > {}; +template +struct has_elem> + : disjunction, Elem>...> {}; + +// Tuple trait for detecting static member element +template +struct all_elem : false_type {}; +template +struct all_elem : true_type {}; +template +struct all_elem::value> > + : all_elem > {}; +template +struct all_elem> + : conjunction, Elem>...> {}; + +// Tuple trait for detecting Underscore member +template +using has_underscore = has_elem; + +template +using all_underscore = all_elem; + +template +using has_int1 = has_elem>; + +template +using has_int0 = has_elem>; + +// +// Slice keeps only the elements of Tuple B that are paired with an Underscore +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +lift_slice(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return lift_slice(x,y); }); + } else if constexpr (is_underscore::value) { + return cute::tuple{b}; + } else { + return cute::tuple<>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Entry point overrides the lifting so that slice(_,b) == b +template +CUTE_HOST_DEVICE constexpr +auto +slice(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return detail::lift_slice(x,y); }); + } else if constexpr (is_underscore::value) { + return b; + } else { + return cute::tuple<>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Dice keeps only the elements of Tuple B that are paired with an Int +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +lift_dice(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return lift_dice(x,y); }); + } else if constexpr (is_underscore::value) { + return cute::tuple<>{}; + } else { + return cute::tuple{b}; + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Entry point overrides the lifting so that dice(1,b) == b +template +CUTE_HOST_DEVICE constexpr +auto +dice(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return detail::lift_dice(x,y); }); + } else if constexpr (is_underscore::value) { + return cute::tuple<>{}; + } else { + return b; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +CUTE_HOST_DEVICE void print(Underscore const&) { + printf("_"); +} + +#if !defined(__CUDACC_RTC__) +CUTE_HOST std::ostream& operator<<(std::ostream& 
os, Underscore const&) { + return os << "_"; +} +#endif + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/util/debug.hpp b/server/punica_kernels/include/cutlass/cute/util/debug.hpp new file mode 100644 index 00000000..966bb115 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/util/debug.hpp @@ -0,0 +1,164 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/** + * \file + * \brief Debugging and logging functionality + */ + +#include + +#include + +namespace cute +{ + +/****************************************************************************** + * Debug and logging macros + ******************************************************************************/ + +/** + * Formats and prints the given message to stdout + */ +#if !defined(CUTE_LOG) +# if !defined(__CUDA_ARCH__) +# define CUTE_LOG(format, ...) printf(format, __VA_ARGS__) +# else +# define CUTE_LOG(format, ...) \ + printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ + blockIdx.x, blockIdx.y, blockIdx.z, \ + threadIdx.x, threadIdx.y, threadIdx.z, \ + __VA_ARGS__); +# endif +#endif + +/** + * Formats and prints the given message to stdout only if DEBUG is defined + */ +#if !defined(CUTE_LOG_DEBUG) +# ifdef DEBUG +# define CUTE_LOG_DEBUG(format, ...) CUTE_LOG(format, __VA_ARGS__) +# else +# define CUTE_LOG_DEBUG(format, ...) 
+# endif +#endif + +/** + * \brief Perror macro with exit + */ +#if !defined(CUTE_ERROR_EXIT) +# define CUTE_ERROR_EXIT(e) \ + do { \ + cudaError_t code = (e); \ + if (code != cudaSuccess) { \ + fprintf(stderr, "<%s:%d> %s:\n %s: %s\n", \ + __FILE__, __LINE__, #e, \ + cudaGetErrorName(code), cudaGetErrorString(code)); \ + fflush(stderr); \ + exit(1); \ + } \ + } while (0) +#endif + +#if !defined(CUTE_CHECK_LAST) +# define CUTE_CHECK_LAST() CUTE_ERROR_EXIT(cudaPeekAtLastError()); CUTE_ERROR_EXIT(cudaDeviceSynchronize()) +#endif + +#if !defined(CUTE_CHECK_ERROR) +# define CUTE_CHECK_ERROR(e) CUTE_ERROR_EXIT(e) +#endif + +// A dummy function that uses compilation failure to print a type +template +CUTE_HOST_DEVICE void +print_type() { + static_assert(sizeof...(T) < 0, "Printing type T."); +} + +template +CUTE_HOST_DEVICE void +print_type(T&&...) { + static_assert(sizeof...(T) < 0, "Printing type T."); +} + +// +// Device-specific helpers +// +// e.g. +// if (thread0()) print(...); +// if (block0()) print(...); +// if (thread(42)) print(...); + +CUTE_HOST_DEVICE +bool +block(int bid) +{ +#if defined(__CUDA_ARCH__) + return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == bid; +#else + return true; +#endif +} + +CUTE_HOST_DEVICE +bool +thread(int tid, int bid) +{ +#if defined(__CUDA_ARCH__) + return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) && block(bid); +#else + return true; +#endif +} + +CUTE_HOST_DEVICE +bool +thread(int tid) +{ + return thread(tid,0); +} + +CUTE_HOST_DEVICE +bool +thread0() +{ + return thread(0,0); +} + +CUTE_HOST_DEVICE +bool +block0() +{ + return block(0); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/util/print.hpp b/server/punica_kernels/include/cutlass/cute/util/print.hpp new file mode 100644 index 00000000..6463e868 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/util/print.hpp @@ -0,0 +1,205 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// +// CUDA compatible print and printf +// + +namespace cute +{ + +CUTE_HOST_DEVICE +int +num_digits(int x) +{ + return (x < 10 ? 1 : + (x < 100 ? 2 : + (x < 1000 ? 3 : + (x < 10000 ? 4 : + (x < 100000 ? 5 : + (x < 1000000 ? 6 : + (x < 10000000 ? 7 : + (x < 100000000 ? 8 : + (x < 1000000000 ? 9 : + 10))))))))); +} + +// +// print dispatcher +// + +CUTE_HOST_DEVICE +void +print(char c) { + printf("%c", c); +} + +CUTE_HOST_DEVICE +void +print(signed char a) { + printf("%d", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(unsigned char a) { + printf("%u", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(short a) { + printf("%hd", a); +} + +CUTE_HOST_DEVICE +void +print(unsigned short a) { + printf("%hu", a); +} + +CUTE_HOST_DEVICE +void +print(int a) { + printf("%d", a); +} + +CUTE_HOST_DEVICE +void +print(unsigned int a) { + printf("%u", a); +} + +CUTE_HOST_DEVICE +void +print(long a) { + printf("%ld", a); +} + +CUTE_HOST_DEVICE +void +print(unsigned long a) { + printf("%lu", a); +} + +CUTE_HOST_DEVICE +void +print(long long a) { + printf("%lld", a); +} + +CUTE_HOST_DEVICE +void +print(unsigned long long a) { + printf("%llu", a); +} + +CUTE_HOST_DEVICE +void +print(float a) { + printf("%f", a); +} + +CUTE_HOST_DEVICE +void +print(double a) { + printf("%f", a); +} + +template +CUTE_HOST_DEVICE +void +print(char const* format, T const&... t) { + printf(format, t...); +} + +CUTE_HOST_DEVICE +void +print(char const* format) { + printf("%s", format); +} + +// +// pretty printing +// + +template +CUTE_HOST_DEVICE void +pretty_print(T const& v) { + printf(" "); print(v); +} + +CUTE_HOST_DEVICE void +pretty_print(bool const& v) { + printf("%*d", 3, int(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(int32_t const& v) { + printf("%*d", 5, v); +} + +CUTE_HOST_DEVICE void +pretty_print(uint32_t const& v) { + printf("%*d", 5, v); +} + +CUTE_HOST_DEVICE void +pretty_print(int64_t const& v) { + printf("%*lld", 5, static_cast(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(uint64_t const& v) { + printf("%*llu", 5, static_cast(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(half_t const& v) { + printf("%*.2f", 8, float(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(float const& v) { + printf("%*.2e", 10, v); +} + +CUTE_HOST_DEVICE void +pretty_print(double const& v) { + printf("%*.3e", 11, v); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cute/util/type_traits.hpp b/server/punica_kernels/include/cutlass/cute/util/type_traits.hpp new file mode 100644 index 00000000..a8cab903 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cute/util/type_traits.hpp @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#include +#include +#include +#include +#else +#include +#include // tuple_size, tuple_element +#include // ptrdiff_t +#include // uintptr_t +#include // numeric_limits +#endif + +#include + +namespace cute +{ + using CUTE_STL_NAMESPACE::enable_if; + using CUTE_STL_NAMESPACE::enable_if_t; +} + +#define __CUTE_REQUIRES(...) typename cute::enable_if<(__VA_ARGS__)>::type* = nullptr +#define __CUTE_REQUIRES_V(...) 
typename cute::enable_if::type* = nullptr + +namespace cute +{ + +// +using CUTE_STL_NAMESPACE::conjunction; +using CUTE_STL_NAMESPACE::conjunction_v; + +using CUTE_STL_NAMESPACE::disjunction; +using CUTE_STL_NAMESPACE::disjunction_v; + +using CUTE_STL_NAMESPACE::negation; +using CUTE_STL_NAMESPACE::negation_v; + +using CUTE_STL_NAMESPACE::void_t; +using CUTE_STL_NAMESPACE::is_void_v; + +using CUTE_STL_NAMESPACE::is_base_of; +using CUTE_STL_NAMESPACE::is_base_of_v; + +using CUTE_STL_NAMESPACE::is_const; +using CUTE_STL_NAMESPACE::is_const_v; +using CUTE_STL_NAMESPACE::is_volatile; +using CUTE_STL_NAMESPACE::is_volatile_v; + +// using CUTE_STL_NAMESPACE::true_type; +// using CUTE_STL_NAMESPACE::false_type; + +using CUTE_STL_NAMESPACE::conditional; +using CUTE_STL_NAMESPACE::conditional_t; + +using CUTE_STL_NAMESPACE::remove_const_t; +using CUTE_STL_NAMESPACE::remove_cv_t; +using CUTE_STL_NAMESPACE::remove_reference_t; + +using CUTE_STL_NAMESPACE::extent; +using CUTE_STL_NAMESPACE::remove_extent; + +using CUTE_STL_NAMESPACE::decay; +using CUTE_STL_NAMESPACE::decay_t; + +using CUTE_STL_NAMESPACE::is_lvalue_reference; +using CUTE_STL_NAMESPACE::is_lvalue_reference_v; + +using CUTE_STL_NAMESPACE::is_reference; +using CUTE_STL_NAMESPACE::is_trivially_copyable; + +using CUTE_STL_NAMESPACE::is_same; +using CUTE_STL_NAMESPACE::is_same_v; + +using CUTE_STL_NAMESPACE::is_arithmetic; +using CUTE_STL_NAMESPACE::is_unsigned; +using CUTE_STL_NAMESPACE::is_unsigned_v; +using CUTE_STL_NAMESPACE::is_signed; +using CUTE_STL_NAMESPACE::is_signed_v; + +using CUTE_STL_NAMESPACE::make_signed; +using CUTE_STL_NAMESPACE::make_signed_t; + +// using CUTE_STL_NAMESPACE::is_integral; +template +using is_std_integral = CUTE_STL_NAMESPACE::is_integral; + +using CUTE_STL_NAMESPACE::is_empty; +using CUTE_STL_NAMESPACE::is_empty_v; + +using CUTE_STL_NAMESPACE::invoke_result_t; + +using CUTE_STL_NAMESPACE::common_type; +using CUTE_STL_NAMESPACE::common_type_t; + +using CUTE_STL_NAMESPACE::remove_pointer; +using CUTE_STL_NAMESPACE::remove_pointer_t; + +// +using CUTE_STL_NAMESPACE::declval; + +template +constexpr T&& forward(remove_reference_t& t) noexcept +{ + return static_cast(t); +} + +template +constexpr T&& forward(remove_reference_t&& t) noexcept +{ + static_assert(! is_lvalue_reference_v, "T cannot be an lvalue reference (e.g., U&)."); + return static_cast(t); +} + +template +constexpr remove_reference_t&& move(T&& t) noexcept +{ + return static_cast&&>(t); +} + +// +using CUTE_STL_NAMESPACE::numeric_limits; + +// +using CUTE_STL_NAMESPACE::ptrdiff_t; + +// +using CUTE_STL_NAMESPACE::uintptr_t; + +// C++20 +// using std::remove_cvref; +template +struct remove_cvref { + using type = remove_cv_t>; +}; + +// C++20 +// using std::remove_cvref_t; +template +using remove_cvref_t = typename remove_cvref::type; + +// +// dependent_false +// +// @brief An always-false value that depends on one or more template parameters. +// See +// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1830r1.pdf +// https://github.com/cplusplus/papers/issues/572 +// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html +template +inline constexpr bool dependent_false = false; + +// +// tuple_size, tuple_element +// +// @brief CuTe-local tuple-traits to prevent conflicts with other libraries. +// For cute:: types, we specialize std::tuple-traits, which is explicitly allowed. +// cute::tuple, cute::array, cute::array_subbyte, etc +// But CuTe wants to treat some external types as tuples as well. 
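
A short sketch of how the dependent_false helper above is typically used: it keeps a static_assert dormant until a template is actually instantiated with a type that reaches the unsupported branch. The function name and the integral/other split are illustrative assumptions only.

// Assumes cute/util/type_traits.hpp (this header) and cute/config.hpp for CUTE_HOST_DEVICE.
template <class T>
CUTE_HOST_DEVICE void store_element(T const&) {
  if constexpr (cute::is_std_integral<T>::value) {
    // ... handle integral elements here ...
  } else {
    static_assert(cute::dependent_false<T>, "store_element: unsupported element type");
  }
}
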
For those, +// we specialize cute::tuple-traits to avoid polluting external traits. +// dim3, uint3, etc + +template +struct tuple_size; + +template +struct tuple_size::type>> : CUTE_STL_NAMESPACE::integral_constant::value> {}; + +// S = : std::integral_constant::value> {}; + +template +constexpr size_t tuple_size_v = tuple_size::value; + +template +struct tuple_element; + +template +struct tuple_element::type>> : CUTE_STL_NAMESPACE::tuple_element {}; + +template +using tuple_element_t = typename tuple_element::type; + +// +// is_valid +// + +namespace detail { + +template ()(declval()...))> +CUTE_HOST_DEVICE constexpr auto +is_valid_impl(int) { return CUTE_STL_NAMESPACE::true_type{}; } + +template +CUTE_HOST_DEVICE constexpr auto +is_valid_impl(...) { return CUTE_STL_NAMESPACE::false_type{}; } + +template +struct is_valid_fn { + template + CUTE_HOST_DEVICE constexpr auto + operator()(Args&&...) const { return is_valid_impl(int{}); } +}; + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr auto +is_valid(F&&) { + return detail::is_valid_fn{}; +} + +template +CUTE_HOST_DEVICE constexpr auto +is_valid(F&&, Args&&...) { + return detail::is_valid_impl(int{}); +} + +} // end namespace cute diff --git a/server/punica_kernels/include/cutlass/cutlass/aligned_buffer.h b/server/punica_kernels/include/cutlass/cutlass/aligned_buffer.h new file mode 100644 index 00000000..d471eda7 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/aligned_buffer.h @@ -0,0 +1,128 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief AlignedBuffer is a container for trivially copyable elements suitable for use in + unions and shared memory. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Modifies semantics of cutlass::Array<> to provide guaranteed alignment. +template < + typename T, + int N, + int Align = 16 +> +struct AlignedBuffer { + + /// Internal storage type + using Storage = uint8_t; + + /// Number of logical elements held in buffer + static int const kCount = N; + + /// Alignment requirement in bytes + static int const kAlign = Align; + + /// Number of storage elements + static int const kBytes = + (sizeof_bits::value * N + 7) / 8; + +private: + + /// Internal storage + alignas(Align) Storage storage[kBytes]; + +public: + + // + // C++ standard members + // + + typedef T value_type; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + typedef value_type *pointer; + typedef value_type const * const_pointer; + + using Array = Array; + using reference = typename Array::reference; + using const_reference = typename Array::const_reference; + +public: + + CUTLASS_HOST_DEVICE + pointer data() { + return reinterpret_cast(storage); + } + + CUTLASS_HOST_DEVICE + const_pointer data() const { + return reinterpret_cast(storage); + } + + CUTLASS_HOST_DEVICE + Storage * raw_data() { + return storage; + } + + CUTLASS_HOST_DEVICE + Storage const * raw_data() const { + return storage; + } + + + CUTLASS_HOST_DEVICE + constexpr bool empty() const { + return !kCount; + } + + CUTLASS_HOST_DEVICE + constexpr size_type size() const { + return kCount; + } + + CUTLASS_HOST_DEVICE + constexpr size_type max_size() const { + return kCount; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/arch.h b/server/punica_kernels/include/cutlass/cutlass/arch/arch.h new file mode 100644 index 00000000..f3e7dc7d --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/arch.h @@ -0,0 +1,109 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines tags for architecture-specific configurations. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +#if defined(__NVCC__) || defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) + +/// Computes laneId within a warp +CUTLASS_DEVICE +int LaneId() { + int ret; + asm ("mov.u32 %0, %%laneid;" : "=r"(ret) : ); + return ret; +} + +/// Computes SM number the thread is running on +CUTLASS_DEVICE +int SmId() { + int ret; + asm ("mov.u32 %0, %%smid;" : "=r"(ret) : ); + return ret; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// +struct Sm50 { + static int const kMinComputeCapability = 50; +}; +struct Sm60 { + static int const kMinComputeCapability = 60; +}; +struct Sm61 { + static int const kMinComputeCapability = 61; +}; +struct Sm70 { + static int const kMinComputeCapability = 70; +}; +struct Sm72 { + static int const kMinComputeCapability = 72; +}; +struct Sm75 { + static int const kMinComputeCapability = 75; +}; +struct Sm80 { + static int const kMinComputeCapability = 80; +}; +struct Sm86 { + static int const kMinComputeCapability = 86; +}; +struct Sm89 { + static int const kMinComputeCapability = 89; +}; +struct Sm90 { + static int const kMinComputeCapability = 90; +}; + +/// Triggers a breakpoint on the device +CUTLASS_DEVICE +void device_breakpoint() { +#if defined(__CUDA_ARCH__) + asm volatile (" brkpt;\n"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/barrier.h b/server/punica_kernels/include/cutlass/cutlass/arch/barrier.h new file mode 100644 index 00000000..e9ec7013 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/barrier.h @@ -0,0 +1,582 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Barrier Operations on SM90+ +*/ + +#pragma once + +#include +#include +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12) +#define CUDA_BARRIER_ENABLED 1 +#else +#define CUDA_BARRIER_ENABLED 0 +#endif + +namespace cutlass { +/// @brief +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts +// This enum class specifies the NamedBarriers reserved by CUTLASS. +enum class ReservedNamedBarriers { + EpilogueBarrier = 0, + TransposeBarrier = 1, + TransformBarrier = 2, + StreamkBarrier0 = 3, + StreamkBarrier1 = 4 + , FirstUserBarrier = StreamkBarrier1 + 1 +}; + + +class NamedBarrier { + + // Data Members: + + // Range = [1 , NUM_THREADS_PER_CTA] + // Range % warp-size (i.e 32) == 0 + uint32_t const num_threads_; + + // Range : [0, 15] + // Note that should be set to the final barrier ID, including ReserveNamedBarrierCount should be considered + uint32_t const id_; + + public: + + // Constructor for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + NamedBarrier(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) + : num_threads_(num_threads), id_(static_cast(reserved_named_barriers)) {} + + // Constructor for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + NamedBarrier(uint32_t num_threads, uint32_t id = 0) + : num_threads_(num_threads), id_(id + ReservedNamedBarrierCount) { + CUTLASS_ASSERT(id + ReservedNamedBarrierCount <= HardwareMaxNumNamedBarriers && "Effective barrier_id should not exceed 16."); + } + + CUTLASS_DEVICE + void arrive_and_wait() const { + // Note: The value of id_ is already the final barrier id (set correctly in the constructor). + NamedBarrier::arrive_and_wait_internal(num_threads_, id_); + } + + CUTLASS_DEVICE + void arrive() const { + // Note: The value of id_ is already the final barrier id (set correctly in the constructor). 
+ NamedBarrier::arrive_internal(num_threads_, id_); + } + + CUTLASS_DEVICE + void sync() const { + NamedBarrier::arrive_and_wait(); + } + + // Static variants + + // Calling interface for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) { + arrive_and_wait_internal(num_threads, barrier_id + ReservedNamedBarrierCount); + } + + // Calling interface for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + static void arrive_and_wait(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { + arrive_and_wait_internal(num_threads, static_cast(reserved_named_barriers)); + } + + // Calling interface for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + static void arrive(uint32_t num_threads, uint32_t barrier_id) { + arrive_internal(num_threads, barrier_id + ReservedNamedBarrierCount); + } + + // Calling interface for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + static void arrive(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { + arrive_internal(num_threads, static_cast(reserved_named_barriers)); + } + + // Calling interface for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + static void sync(uint32_t num_threads, uint32_t barrier_id) { + sync_internal(num_threads, barrier_id + ReservedNamedBarrierCount); + } + + // Calling interface for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + static void sync(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { + sync_internal(num_threads, static_cast(reserved_named_barriers)); + } + + private: + CUTLASS_DEVICE + static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void sync_internal(uint32_t num_threads, uint32_t barrier_id) { + NamedBarrier::arrive_and_wait_internal(num_threads, barrier_id); + } + + public: + // Currently we reserve 8 NamedBarriers for CUTLASS' own use cases, + // while leaving the renaming for general users. + static const uint32_t ReservedNamedBarrierCount = static_cast(ReservedNamedBarriers::FirstUserBarrier); + static const uint32_t HardwareMaxNumNamedBarriers = 16; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Hopper introduces a new cluster-wide barrier which handle with Cluster-wide arrive-wait behaviour. +// This is an extension to the Ampere arrive-wait barriers +// Note : Ampere arrive-wait Barriers have a larger max-arrive count (2^30) than Hopper arrive-wait Barriers (2^20). 
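Aside (illustrative only, not part of the patch): a minimal sketch of how device code might drive the NamedBarrier wrapper defined above once this header is vendored. The kernel name and thread partition are hypothetical; the sketch assumes 64 participating threads (a multiple of the warp size, as the wrapper requires) and user-visible barrier id 0, which the constructor remaps past the CUTLASS-reserved ids.

    #include "cutlass/arch/barrier.h"

    __global__ void producer_consumer_sketch() {
      // Only the first two warps participate; the rest of the CTA is unaffected.
      if (threadIdx.x < 64) {
        cutlass::arch::NamedBarrier bar(/*num_threads=*/64, /*id=*/0);
        // ... stage data into shared memory ...
        bar.arrive_and_wait();  // lowers to bar.sync on the remapped barrier id with 64 threads
      }
    }

Because the wrapper reserves the first few hardware named-barrier ids for CUTLASS-internal use (epilogue, transpose, stream-K), user code can always pass small ids starting at 0 without colliding with them.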
+struct ClusterBarrier { + + using ValueType = uint64_t; + +protected: + // Can never be initialized - can only be aliased to smem + ValueType barrier_; + +public: + + CUTLASS_DEVICE + ClusterBarrier() = delete; + + CUTLASS_DEVICE + void init(uint32_t arrive_count) const { + ClusterBarrier::init(&this->barrier_, arrive_count); + } + + CUTLASS_DEVICE + uint32_t test_wait(uint32_t phase, uint32_t pred=true) const { + return ClusterBarrier::test_wait(&this->barrier_, phase, pred); + } + + CUTLASS_DEVICE + uint32_t try_wait(uint32_t phase) const { + return ClusterBarrier::try_wait(&this->barrier_, phase); + } + + CUTLASS_DEVICE + void wait(uint32_t phase) const { + ClusterBarrier::wait(&this->barrier_, phase); + } + + // Barrier arrive on local smem + CUTLASS_DEVICE + void arrive() const { + ClusterBarrier::arrive(&this->barrier_); + } + + // Remote SMEM arrive with a perdicate (usually done to pick the thread doing the arrive) + CUTLASS_DEVICE + void arrive(uint32_t cta_id, uint32_t pred = true ) const { + ClusterBarrier::arrive(&this->barrier_, cta_id, pred); + } + + // + // Static Versions + // + CUTLASS_DEVICE + static void init(ValueType const* smem_ptr, uint32_t arrive_count) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.init.shared::cta.b64 [%1], %0; \n" + "}" + : + : "r"(arrive_count), "r"(smem_addr)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + // Static version of wait - in case we don't want to burn a register + CUTLASS_DEVICE + static void wait(ValueType const* smem_ptr, uint32_t phase) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + // Arbitrarily large timer value after which try-wait expires and re-tries. + uint32_t ticks = 0x989680; + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}" + : + : "r"(smem_addr), "r"(phase), "r"(ticks)); + +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static uint32_t test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t waitComplete; + + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + ".reg .pred P2; \n\t" + "setp.eq.u32 P2, %3, 1;\n\t" + "@P2 mbarrier.test_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(waitComplete) + : "r"(smem_addr), "r"(phase), "r"(pred)); + + return waitComplete; +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + return 0; + } + + CUTLASS_DEVICE + static uint32_t try_wait(ValueType const* smem_ptr, uint32_t phase) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t waitComplete; + + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(waitComplete) + : "r"(smem_addr), "r"(phase)); + + return waitComplete; +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + return 0; + } + + // Static Predicated version of the above - in case we know the address. 
+ CUTLASS_DEVICE + static void arrive(ValueType const* smem_ptr, uint32_t cta_id, uint32_t pred) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b32 remAddr32;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" + "@p mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" + "}" + : + : "r"(smem_addr), "r"(cta_id), "r"(pred)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + // Barrier arrive on local smem + CUTLASS_DEVICE + static void arrive(ValueType const* smem_ptr) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.arrive.shared::cta.b64 _, [%0];\n\t" + "}" + : + : "r"(smem_addr)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void invalidate(ValueType const* smem_ptr) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.ival.shared::cta.b64 [%0]; \n\t" + "}" + : + : "r"(smem_addr)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 also introduces a new type of cluster-barrier which supports sync. +// not just based on Arrive Count, but also transaction count (in bytes) +struct ClusterTransactionBarrier : public ClusterBarrier { + + CUTLASS_DEVICE + ClusterTransactionBarrier() = delete; + + // Performs an arrive operation + expected transaction bytes increment + CUTLASS_DEVICE + void arrive_and_expect_tx(uint32_t transaction_bytes) const { + ClusterTransactionBarrier::arrive_and_expect_tx(&this->barrier_, transaction_bytes); + } + + // Performs an arrive operation + expected transaction bytes increment + CUTLASS_DEVICE + void arrive_and_expect_tx(uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred = 1u) const { + ClusterTransactionBarrier::arrive_and_expect_tx(&this->barrier_, transaction_bytes , cta_id, pred); + } + + // Performs an expected transaction bytes increment without doing an arrive operation + CUTLASS_DEVICE + void expect_transaction(uint32_t transaction_bytes) const { + ClusterTransactionBarrier::expect_transaction(&this->barrier_, transaction_bytes); + } + + // Performs an expected transaction bytes decrement without doing an arrive operation + CUTLASS_DEVICE + void complete_transaction(uint32_t transaction_bytes, uint32_t pred = 1) const { + uint32_t cta_rank = cute::block_rank_in_cluster(); + ClusterTransactionBarrier::complete_transaction(&this->barrier_, cta_rank, transaction_bytes, pred); + } + + // Performs an expected transaction bytes decrement without doing an arrive operation + CUTLASS_DEVICE + void complete_transaction(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { + ClusterTransactionBarrier::complete_transaction(&this->barrier_, dst_cta_id, transaction_bytes, pred); + } + + // + // Static Versions + // + + // Performs an arrive operation + expected transaction bytes increment + CUTLASS_DEVICE + static void arrive_and_expect_tx(ValueType const* smem_ptr, uint32_t transaction_bytes) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" + "}" + : + : "r"(transaction_bytes), "r"(smem_addr)); +#elif 
defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + // Performs an arrive operation + expected transaction bytes increment for a remote cta_id in a Cluster + CUTLASS_DEVICE + static void arrive_and_expect_tx( + ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b32 remAddr32;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" + "@p mbarrier.arrive.expect_tx.shared::cluster.b64 _, [remAddr32], %3;\n\t" + "}" + : + : "r"(smem_addr), "r"(cta_id), "r"(pred), "r"(transaction_bytes)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + // Performs an expected transaction bytes increment without doing an arrive operation + CUTLASS_DEVICE + static void expect_transaction(ValueType const* smem_ptr, uint32_t transaction_bytes) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.expect_tx.shared::cta.b64 [%1], %0; \n\t" + "}" + : + : "r"(transaction_bytes), "r"(smem_addr)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + // Performs an expected transaction bytes decrement without doing an arrive operation + CUTLASS_DEVICE + static void complete_transaction( + ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + smem_addr = cute::set_block_rank(smem_addr, dst_cta_id); + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p mbarrier.complete_tx.shared::cluster.relaxed.cluster.b64 [%1], %0;" + "}" + : + : "r"(transaction_bytes), "r"(smem_addr), "r"(pred)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + // + // DEPRECATED APIs + // + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE + void arrive_and_reset_bytes(uint32_t transaction_bytes) const { + arrive_and_expect_tx(transaction_bytes); + } + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE + void arrive_and_reset_bytes(uint32_t transaction_bytes, uint32_t cta_id) const { + arrive_and_expect_tx(transaction_bytes, cta_id); + } + [[deprecated("Use expect_transaction instead")]] CUTLASS_DEVICE + void reset_bytes(uint32_t transaction_bytes) const { + expect_transaction(transaction_bytes); + } + [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE + void commit(uint32_t transaction_bytes, uint32_t pred = 1) const { + complete_transaction(transaction_bytes, pred); + } + [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE + void commit(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { + complete_transaction(dst_cta_id, transaction_bytes, pred); + } + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE + static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { + arrive_and_expect_tx(smem_ptr, transaction_bytes); + } + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE + static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { + arrive_and_expect_tx(smem_ptr, transaction_bytes, cta_id, pred); + } + [[deprecated("Use expect_transaction instead")]] CUTLASS_DEVICE + static void reset_bytes(ValueType 
const* smem_ptr, uint32_t transaction_bytes) { + expect_transaction(smem_ptr, transaction_bytes); + } + [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE + static void commit(ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) { + complete_transaction(smem_ptr, dst_cta_id, transaction_bytes, pred); + } +}; + +// Helps with visibility of barrier init operations across warps / cta / cluster +// Available as a separate function so as to batch inits across barriers and fence once +// Note : It must be composed with an appropriate sync instruction with the right scope +// to ensure visibility eg. __syncthreads() or a cluster_arrive() + cluster_wait() +CUTLASS_DEVICE +void fence_barrier_init() { +#if CUDA_BARRIER_ENABLED + asm volatile( + "{\n\t" + "fence.mbarrier_init.release.cluster; \n" + "}" + ::); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif +} + +// Issue a shared memory fence for async operations +CUTLASS_DEVICE +void fence_view_async_shared() { +#if CUDA_BARRIER_ENABLED + asm volatile ( + "{\n\t" + "fence.proxy.async.shared::cta; \n" + "}" + ::); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif +} + +// Arrive on completion of in-flight cp.async operations issued by the calling thread +CUTLASS_DEVICE +void cpasync_barrier_arrive(uint64_t const* smem_ptr) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "cp.async.mbarrier.arrive.shared::cta.b64 [%0];\n\t" + "}" + : + : "r"(smem_addr)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // end namespace arch +} // end namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/cache_operation.h b/server/punica_kernels/include/cutlass/cutlass/arch/cache_operation.h new file mode 100644 index 00000000..9d2344bf --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/cache_operation.h @@ -0,0 +1,66 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Directives related to cache operations +*/ +#pragma once + +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Controls PTX cache operations +struct CacheOperation { + enum Kind { + /// Cache at all levels - accessed again + Always, + /// Cache at global level + Global, + /// Streaming - likely to be accessed once + Streaming, + /// Indicates the line will not be used again + LastUse, + /// Don't cache, and fetch again + Volatile, + /// Write back at all coherent levels + WriteBack, + /// Write through to system memory + WriteThrough + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/memory.h b/server/punica_kernels/include/cutlass/cutlass/arch/memory.h new file mode 100644 index 00000000..71304516 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/memory.h @@ -0,0 +1,601 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Architecture-specific operators on memory +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/cache_operation.h" + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Fragment type to store loaded data + typename AccessType, + /// The bytes of loading + int LoadBytes, + /// Cache operation + CacheOperation::Kind cache_op = CacheOperation::Always + > +struct global_load; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Specializations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ + (__CUDACC_VER_MAJOR__ > 11)) && \ + defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) + #define CUTLASS_ENABLE_L2_PREFETCH 1 +#else + #define CUTLASS_ENABLE_L2_PREFETCH 0 +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// The redundant mov PTX instruction is used to enforce the compiler to +// keep the initializing code before ld.global +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint4 *data = reinterpret_cast(&D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %9, 0;\n" + " mov.b32 %0, %10;\n" + " mov.b32 %1, %11;\n" + " mov.b32 %2, %12;\n" + " mov.b32 %3, %13;\n" + " mov.b32 %4, %14;\n" + " mov.b32 %5, %15;\n" + " mov.b32 %6, %16;\n" + " mov.b32 %7, %17;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%8];\n" + " @p ld.global.L2::128B.v4.u32 {%4, %5, %6, %7}, [%18];\n" +#else + " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n" + " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%18];\n" +#endif + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w), + "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w) + : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y), + "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y), + "r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint4 *data = reinterpret_cast(&D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %9, 0;\n" + " mov.b32 %0, %10;\n" + " mov.b32 %1, %11;\n" + " mov.b32 %2, %12;\n" + " mov.b32 %3, %13;\n" + " mov.b32 %4, %14;\n" + " mov.b32 %5, %15;\n" + " mov.b32 %6, %16;\n" + " mov.b32 %7, %17;\n" + " @p ld.global.lu.v4.u32 {%0, %1, %2, %3}, [%8];\n" + " @p ld.global.lu.v4.u32 {%4, %5, %6, %7}, [%18];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w), + "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w) + : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y), + "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y), + "r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint4 &data = reinterpret_cast(D); + asm volatile( + "{\n" + " .reg 
.pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " mov.b32 %0, %6;\n" + " mov.b32 %1, %7;\n" + " mov.b32 %2, %8;\n" + " mov.b32 %3, %9;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%4];\n" +#else + " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" +#endif + "}\n" + : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) + : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint4 &data = reinterpret_cast(D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " mov.b32 %0, %6;\n" + " mov.b32 %1, %7;\n" + " mov.b32 %2, %8;\n" + " mov.b32 %3, %9;\n" + " @p ld.global.lu.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) + : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint2 &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + " mov.b32 %0, %4;\n" + " mov.b32 %1, %5;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p ld.global.L2::128B.v2.u32 {%0, %1}, [%2];\n" +#else + " @p ld.global.v2.u32 {%0, %1}, [%2];\n" +#endif + "}\n" + : "=r"(data.x), "=r"(data.y) + : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint2 &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + " mov.b32 %0, %4;\n" + " mov.b32 %1, %5;\n" + " @p ld.global.lu.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data.x), "=r"(data.y) + : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + unsigned &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " mov.b32 %0, %3;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p ld.global.L2::128B.u32 %0, [%1];\n" +#else + " @p ld.global.u32 %0, [%1];\n" +#endif + "}\n" + : "=r"(data) + : "l"(ptr), "r"((int)pred_guard), "r"(data)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + unsigned &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " mov.b32 %0, %3;\n" + " @p ld.global.lu.u32 %0, [%1];\n" + "}\n" + : "=r"(data) + : "l"(ptr), "r"((int)pred_guard), "r"(data)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint16_t &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " mov.b16 %0, %3;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p ld.global.L2::128B.u16 %0, [%1];\n" +#else + " @p ld.global.u16 %0, [%1];\n" +#endif + "}\n" + : "=h"(data) + : "l"(ptr), "r"((int)pred_guard), "h"(data)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint16_t &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " mov.b16 %0, %3;\n" + " @p ld.global.lu.u16 %0, [%1];\n" + "}\n" + : "=h"(data) + 
: "l"(ptr), "r"((int)pred_guard), "h"(data)); + } +}; + +template +struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + if (pred_guard) D = *(reinterpret_cast(ptr)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Fragment type to store data + typename AccessType, + /// The bytes of storing + int StoreBytes + > +struct global_store; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Specializations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + uint4 const *data = reinterpret_cast(&D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" + " @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n" + " @p st.global.v4.u32 [%11], {%12, %13, %14, %15};\n" + " @p st.global.v4.u32 [%16], {%17, %18, %19, %20};\n" + "}\n" + : + : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), + "r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16), + "r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w), + "l"(((uint8_t *)ptr) + 32), + "r"(data[2].x), "r"(data[2].y), "r"(data[2].z), "r"(data[2].w), + "l"(((uint8_t *)ptr) + 48), + "r"(data[3].x), "r"(data[3].y), "r"(data[3].z), "r"(data[3].w)); + } +}; + + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + uint4 const *data = reinterpret_cast(&D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" + " @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n" + "}\n" + : + : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), + "r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16), + "r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w)); + } +}; + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + uint4 const &data = reinterpret_cast(D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" + "}\n" + : + : "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w), "r"((int)pred_guard)); + } +}; + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + uint2 const &data = reinterpret_cast(D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + " @p st.global.v2.u32 [%0], {%1, %2};\n" + "}\n" + : + : "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard)); + } +}; + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + uint32_t const &data = reinterpret_cast(D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " @p st.global.u32 [%0], %1;\n" + "}\n" + : + : "l"(ptr), "r"(data), "r"((int)pred_guard)); + } +}; + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + uint16_t const &data = reinterpret_cast(D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " @p st.global.u16 [%0], %1;\n" + "}\n" + : + : "l"(ptr), "h"(data), "r"((int)pred_guard)); + } 
+}; + +template +struct global_store { + CUTLASS_DEVICE + global_store(AccessType const &D, void *ptr, bool pred_guard) { + if (pred_guard) *(reinterpret_cast(ptr)) = D; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// ld.shared +template +CUTLASS_DEVICE +void shared_load(void *dst, uint32_t ptr); + +/// ld.shared - 16b +template <> +CUTLASS_DEVICE +void shared_load<2>(void *dst, uint32_t ptr) { + asm volatile("ld.shared.u16 %0, [%1];\n" + : "=h"(*reinterpret_cast(dst)) + : "r"(ptr)); +} + +/// ld.shared - 32b +template <> +CUTLASS_DEVICE +void shared_load<4>(void *dst, uint32_t ptr) { + asm volatile("ld.shared.u32 %0, [%1];\n" + : "=r"(*reinterpret_cast(dst)) + : "r"(ptr)); +} + +/// ld.shared - 64b +template <> +CUTLASS_DEVICE +void shared_load<8>(void *dst, uint32_t ptr) { + uint2 *dst_u64 = reinterpret_cast(dst); + asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n" + : + "=r"(dst_u64->x), + "=r"(dst_u64->y) + : "r"(ptr)); +} + +/// ld.shared - 128b +template <> +CUTLASS_DEVICE +void shared_load<16>(void *dst, uint32_t ptr) { + uint4 *dst_u128 = reinterpret_cast(dst); + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : + "=r"(dst_u128->x), + "=r"(dst_u128->y), + "=r"(dst_u128->z), + "=r"(dst_u128->w) + : "r"(ptr)); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// st.shared +template +CUTLASS_DEVICE +void shared_store(uint32_t ptr, void const *src); + +/// st.shared - 16b +template <> +CUTLASS_DEVICE +void shared_store<2>(uint32_t ptr, void const *src) { + asm volatile("st.shared.u16 [%0], %1;\n" + : : + "r"(ptr), + "h"(*reinterpret_cast(src)) + ); +} + +/// st.shared - 32b +template <> +CUTLASS_DEVICE +void shared_store<4>(uint32_t ptr, void const *src) { + asm volatile("st.shared.u32 [%0], %1;\n" + : : + "r"(ptr), + "r"(*reinterpret_cast(src)) + ); +} + +/// st.shared - 64b +template <> +CUTLASS_DEVICE +void shared_store<8>(uint32_t ptr, void const *src) { + uint2 const *dst_u64 = reinterpret_cast(src); + asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" + : : + "r"(ptr), + "r"(dst_u64->x), + "r"(dst_u64->y) + ); +} + +/// st.shared - 128b +template <> +CUTLASS_DEVICE +void shared_store<16>(uint32_t ptr, void const *src) { + uint4 const *dst_u128 = reinterpret_cast(src); + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" + : : + "r"(ptr), + "r"(dst_u128->x), + "r"(dst_u128->y), + "r"(dst_u128->z), + "r"(dst_u128->w) + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/memory_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/memory_sm75.h b/server/punica_kernels/include/cutlass/cutlass/arch/memory_sm75.h new file mode 100644 index 00000000..2a1d5ea4 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/memory_sm75.h @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Architecture-specific operators on memory added for SM75 +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cute/arch/copy_sm75.hpp" +#include "cute/arch/util.hpp" + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Layout of destination matrix (column-major implies transpose) + typename Layout, + /// .x1, .x2, or .x4 + int MatrixCount +> +inline __device__ void ldsm(Array & D, void const* ptr); + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Determine the appropriate way to target PTX's "ldmatrix" instruction. 
+// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// CUTLASS helper to get SMEM pointer +inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) { + return cute::cast_smem_ptr_to_uint(ptr); +} + +/// CUTLASS helper to get SMEM pointer +inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) { + return cutlass_get_smem_pointer(const_cast(ptr)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ void ldsm( + Array & D, + void const* ptr) { + + #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + + unsigned addr = cutlass_get_smem_pointer(ptr); + + int x; + asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr)); + reinterpret_cast(D) = x; + + #else + + CUTLASS_UNUSED(D); + CUTLASS_UNUSED(ptr); + CUTLASS_NOT_IMPLEMENTED(); + + #endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ void ldsm( + Array & D, + void const* ptr) { + + #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + + unsigned addr = cutlass_get_smem_pointer(ptr); + + int x, y; + asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr)); + reinterpret_cast(D) = make_int2(x, y); + + #else + + CUTLASS_UNUSED(D); + CUTLASS_UNUSED(ptr); + CUTLASS_NOT_IMPLEMENTED(); + + #endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ void ldsm( + Array & D, + void const* ptr) { + + #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + + unsigned addr = cutlass_get_smem_pointer(ptr); + + int x, y, z, w; + asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr)); + reinterpret_cast(D) = make_int4(x, y, z, w); + + #else + + CUTLASS_UNUSED(D); + CUTLASS_UNUSED(ptr); + CUTLASS_NOT_IMPLEMENTED(); + + #endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Transpose on 16b granularity +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ void ldsm( + Array & D, + void const* ptr) { + + #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + + unsigned addr = cutlass_get_smem_pointer(ptr); + + int x; + asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr)); + reinterpret_cast(D) = x; + + #else + + CUTLASS_UNUSED(D); + CUTLASS_UNUSED(ptr); + CUTLASS_NOT_IMPLEMENTED(); + + #endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ void ldsm( + Array & D, + void const* ptr) { + + #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + + unsigned addr = cutlass_get_smem_pointer(ptr); + + int x, y; + asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr)); + reinterpret_cast(D) = make_int2(x, y); + + #else + + CUTLASS_UNUSED(D); + CUTLASS_UNUSED(ptr); + CUTLASS_NOT_IMPLEMENTED(); + + #endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ void ldsm( + Array & D, + void const* ptr) { + + #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + + unsigned addr = cutlass_get_smem_pointer(ptr); + + int x, y, z, w; 
+ asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr)); + reinterpret_cast(D) = make_int4(x, y, z, w); + + #else + + CUTLASS_UNUSED(D); + CUTLASS_UNUSED(ptr); + CUTLASS_NOT_IMPLEMENTED(); + + #endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct shared_load_op { + CUTLASS_DEVICE + shared_load_op(AccessType &D, void const *ptr) { + D = *reinterpret_cast(ptr); + } +}; + +template +CUTLASS_DEVICE void shared_load(AccessType &D, void const *ptr) { + shared_load_op(D, ptr); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct shared_load_op { + CUTLASS_DEVICE + shared_load_op(AccessType &D, void const *ptr) { + unsigned addr = cutlass_get_smem_pointer(ptr); + + uint4 v; + asm volatile ("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];" : + "=r"(v.x), "=r"(v.y), "=r"(v.z), "=r"(v.w) : "r"(addr)); + + D = reinterpret_cast(v); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct shared_load_op { + CUTLASS_DEVICE + shared_load_op(AccessType &D, void const *ptr) { + unsigned addr = cutlass_get_smem_pointer(ptr); + + uint2 v; + asm volatile ("ld.shared.v2.b32 {%0, %1}, [%2];" : + "=r"(v.x), "=r"(v.y) : "r"(addr)); + + D = reinterpret_cast(v); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/memory_sm80.h b/server/punica_kernels/include/cutlass/cutlass/arch/memory_sm80.h new file mode 100644 index 00000000..434fab03 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/memory_sm80.h @@ -0,0 +1,463 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Architecture-specific operators on memory added for SM80 +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/cache_operation.h" + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + #define CUDA_CP_ASYNC_ACTIVATED 1 +#else + #define CUDA_CP_ASYNC_ACTIVATED 0 +#endif + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Initiates an asynchronous copy from global memory to shared memory. +/// +/// cp.async +/// +template < + /// Size of the access in bytes + int SizeInBytes, + /// Cache operation + CacheOperation::Kind cache_op = CacheOperation::Always> +struct cp_async; + +/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate +/// the entire transfer, zeros are written to SMEM if the guard predicate is false. +/// +/// cp.async +/// +template < + /// Size of the access in bytes + int SizeInBytes, + /// Cache operation + CacheOperation::Kind cache_op = CacheOperation::Always> +struct cp_async_zfill; + +/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate +/// the entire transfer, nans (0x7eff) are written to SMEM if the guard predicate is false. +/// +/// cp.async +/// +template < + /// Size of the access in bytes + int SizeInBytes, + /// Cache operation + CacheOperation::Kind cache_op = CacheOperation::Always> +struct cp_async_nan; + +/// Either 0 or 1 are written to SMEM based on input element type +/// Used for diagonal elements of triangular matrix of BLAS3 functions +/// +/// st.shared +/// +template < + /// Type of Element + typename Element, + /// If the data is for a Hermitian matrix diagonal + bool IsHermitianData = false> +struct cp_async_diag; + +static const uint32_t OOB_NAN_F16 = 0x7eff; +static const uint32_t OOB_NAN_F16x2 = ((OOB_NAN_F16 << 16) | OOB_NAN_F16); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct cp_async { + + /// Copy + CUTLASS_DEVICE + cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + #if CUDA_CP_ASYNC_ACTIVATED + + // Make sure the size is supported. 
+ static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16), + "Size is not supported"); + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n" +#else + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" +#endif + "}\n" ::"r"((int)pred_guard), + "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); + + #else + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + #endif + } +}; + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct cp_async_zfill { + + /// Copy with zero fill + CUTLASS_DEVICE + cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { + #if CUDA_CP_ASYNC_ACTIVATED + + // Make sure the size is supported. + static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16), + "Size is not supported"); + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + int src_in_bytes = (pred_guard ? SizeInBytes : 0); + + asm volatile( +#if CUTLASS_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), +#else + "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), +#endif + "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes)); + + #else + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + else { + AccessType zeros; + zeros.clear(); + *static_cast(smem_ptr) = zeros; + } + #endif + } +}; + +/// Partial specialization +template <> +struct cp_async_nan<16, CacheOperation::Always> { + static int const kSizeInBytes = 16; + + /// Copy with nan fill + CUTLASS_DEVICE + cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) { + #if CUDA_CP_ASYNC_ACTIVATED + + static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2, + OOB_NAN_F16x2, OOB_NAN_F16x2}; + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n" +#else + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" +#endif + " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n" + "}\n" + : + : "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), + "n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z), + "r"(OOB_NAN_F16x8.w)); + + #else + + CUTLASS_UNUSED(smem_ptr); + CUTLASS_UNUSED(global_ptr); + CUTLASS_UNUSED(pred_guard); + CUTLASS_NOT_IMPLEMENTED(); + + #endif + } +}; + +/// Partial specialization to write one (1) +template +struct cp_async_diag { + using Element = Element_; + + CUTLASS_DEVICE + cp_async_diag(void *smem_ptr) { + #if CUDA_CP_ASYNC_ACTIVATED + + /// Values for the diagonal elements of the triangular input matrix + static __constant__ uint2 DIAG_DATA_DOUBLE_ONE = {0x3ff00000, 0x00000000}; + static __constant__ uint1 DIAG_DATA_FLOAT_ONE = {0x3f800000}; + static __constant__ uint1 DIAG_DATA_ZERO = {0x00000000}; + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + + if (platform::is_same>::value) { + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" + : : + "r"(smem_int_ptr), "r"(DIAG_DATA_DOUBLE_ONE.y), "r"(DIAG_DATA_DOUBLE_ONE.x), + "r"(DIAG_DATA_ZERO.x), "r"(DIAG_DATA_ZERO.x)); + } else if (platform::is_same>::value) { + asm 
volatile("st.shared.v2.u32 [%0], {%1, %2};\n" + : : + "r"(smem_int_ptr), "r"(DIAG_DATA_FLOAT_ONE.x), "r"(DIAG_DATA_ZERO.x)); + } else if (platform::is_same::value) { + asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" + : : + "r"(smem_int_ptr), "r"(DIAG_DATA_DOUBLE_ONE.y),"r"(DIAG_DATA_DOUBLE_ONE.x)); + } else if (platform::is_same::value) { + asm volatile("st.shared.u32 [%0], %1;\n" + : : + "r"(smem_int_ptr), "r"(DIAG_DATA_FLOAT_ONE.x)); + } else { + CUTLASS_UNUSED(smem_int_ptr); + CUTLASS_NOT_IMPLEMENTED(); + } + + #else + + CUTLASS_UNUSED(smem_ptr); + CUTLASS_NOT_IMPLEMENTED(); + + #endif + } +}; + +/// Partial specialization to write zero for the imaginary part of Hermitian data +template +struct cp_async_diag { + using Element = Element_; + + CUTLASS_DEVICE + cp_async_diag(void *smem_ptr) { + #if CUDA_CP_ASYNC_ACTIVATED + + /// Values for the diagonal elements of the triangular input matrix + static __constant__ uint1 DIAG_DATA_ZERO = {0x00000000}; + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + + if (platform::is_same>::value) { + asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" + : : + "r"(smem_int_ptr), "r"(DIAG_DATA_ZERO.x), "r"(DIAG_DATA_ZERO.x)); + } else if (platform::is_same>::value) { + asm volatile("st.shared.u32 [%0], %1;\n" + : : + "r"(smem_int_ptr), "r"(DIAG_DATA_ZERO.x)); + } else { + CUTLASS_UNUSED(smem_int_ptr); + CUTLASS_NOT_IMPLEMENTED(); + } + + #else + + CUTLASS_UNUSED(smem_ptr); + CUTLASS_NOT_IMPLEMENTED(); + + #endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct cp_async { + + /// Copy + CUTLASS_DEVICE + cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + #if CUDA_CP_ASYNC_ACTIVATED + + static_assert(SizeInBytes == 16, + "cp.async only supports CacheOperation::Global when access size is 16B."); + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" +#else + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" +#endif + "}\n" ::"r"((int)pred_guard), + "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); + + #else + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + #endif + } +}; + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct cp_async_zfill { + + /// Copy with zero fill + CUTLASS_DEVICE + cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + #if CUDA_CP_ASYNC_ACTIVATED + + static_assert(SizeInBytes == 16, + "cp.async only supports CacheOperation::Global when access size is 16B."); + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + int src_in_bytes = (pred_guard ? 
SizeInBytes : 0); + asm volatile( +#if CUTLASS_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), +#else + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), +#endif + "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes)); + + #else + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + else { + AccessType zeros; + zeros.clear(); + *static_cast(smem_ptr) = zeros; + } + #endif + } +}; + +/// Partial specialization +template <> +struct cp_async_nan<16, CacheOperation::Global> { + static int const kSizeInBytes = 16; + + /// Copy with nan fill + CUTLASS_DEVICE + cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) { + #if CUDA_CP_ASYNC_ACTIVATED + + static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2, + OOB_NAN_F16x2, OOB_NAN_F16x2}; + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" +#else + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" +#endif + " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n" + "}\n" + : + : "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), + "n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z), + "r"(OOB_NAN_F16x8.w)); + + #else + + CUTLASS_UNUSED(smem_ptr); + CUTLASS_UNUSED(global_ptr); + CUTLASS_UNUSED(pred_guard); + CUTLASS_NOT_IMPLEMENTED(); + + #endif + } +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +CUTLASS_DEVICE +void cp_async_fence() { + #if CUDA_CP_ASYNC_ACTIVATED + asm volatile("cp.async.commit_group;\n" ::); + #endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Blocks until all but previous cp.async.commit_group operations have committed. +template +CUTLASS_DEVICE void cp_async_wait() { + #if CUDA_CP_ASYNC_ACTIVATED + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); + #endif +} + +/// Blocks until all previous cp.async.commit_group operations have committed. +template <> +CUTLASS_DEVICE void cp_async_wait<0>() { + #if CUDA_CP_ASYNC_ACTIVATED + asm volatile("cp.async.wait_all;\n" ::); + #endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/mma.h b/server/punica_kernels/include/cutlass/cutlass/arch/mma.h new file mode 100644 index 00000000..633a0804 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/mma.h @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates exposing architecture support for multiply-add operations +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the operation implied by MMA. +struct OpMultiplyAdd {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the result is saturated to MAX_FLOAT|MIN_FLOAT or MAX_INT|MIN_INT +struct OpMultiplyAddSaturate {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the input is converted to a narrower type (BF16) +struct OpMultiplyAddFastBF16 {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the input is converted to a narrower type (F16) +struct OpMultiplyAddFastF16 {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the input data types are mixed and the narrower type is +/// upcasted to the wider type +struct OpMultiplyAddMixedInputUpcast {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the input is converted to 2 (big and small) TF32 components +// Perform 3xTF32 or 4xTF32 for every F32 output element +struct OpMultiplyAddFastF32 {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the input is converted to 2 (big and small) TF32 components +// Perform 3xTF32 or 4xTF32 for every complex output element +struct OpMultiplyAddComplexFastF32 {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating that staged accumulation is not to be used. 
This is valid only for SM89 +/// FP8 kernels. +struct OpMultiplyAddFastAccum; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the complex multiply-add operation +struct OpMultiplyAddComplex {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the gaussian complex multiply-add operation +struct OpMultiplyAddGaussianComplex {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the inner product is defined by (XOR, POPC) +struct OpXorPopc {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the inner product is defined by (AND, POPC) +struct OpAndPopc {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag classifying math operators as thread-level operations. +struct OpClassSimt {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag classifying operators as Tensor Core operations. +struct OpClassTensorOp {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Tag classifying operators as WMMA Tensor Core operations +struct OpClassWmmaTensorOp {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag classifying operators as Tensor Core with structure sparse operations. +struct OpClassSparseTensorOp {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Size of the matrix product (concept: GemmShape) + typename Shape_, + /// Number of threads participating + int kThreads_, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Inner product operator + typename Operator +> +struct Mma; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation - specialized for 1x1x1x1 matrix multiply operation +template < + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Inner product operator + typename Operator_ +> +struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, Operator_> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = Operator_; + using ElementC = ElementC_; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + + multiply_add op; + + d[0] = op(a[0], b[0], c[0]); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specifies 
internal data type for computation +struct SPFormatType { + enum Kind { + Thread + }; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Size of the matrix product (concept: GemmShape) + typename Shape_, + /// Number of threads participating + int kThreads_, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Inner product operator + typename Operator, + /// Specifies meta data format + SPFormatType::Kind SPFormat = SPFormatType::Thread +> +struct SparseMma; + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Specializations for each compute capability +// + +#include "cutlass/arch/mma_sm50.h" +#include "cutlass/arch/mma_sm60.h" +#include "cutlass/arch/mma_sm61.h" +#include "cutlass/arch/mma_sm70.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" +#include "cutlass/arch/mma_sparse_sm80.h" +#include "cutlass/arch/mma_sm89.h" +#include "cutlass/arch/mma_sparse_sm89.h" +#include "cutlass/arch/mma_sm90.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { +namespace detail { +/// Helper for determining whether staged accumulation should be used for a given operator +template +struct UseStagedAccumulation { + static bool const value = platform::is_same::value || + platform::is_same::value || + is_sm89_staged_policy_v; +}; +} // namespace detail +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm50.h b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm50.h new file mode 100644 index 00000000..98ff18be --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm50.h @@ -0,0 +1,432 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/functional.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = float; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + d[0] = a[0] * b[0] + c[0]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = double; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + + d[0] = a[0] * b[0] + c[0]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = int; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + + d[0] = a[0] * b[0] + c[0]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma< + gemm::GemmShape<1, 1, 1>, + 1, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array, 1> const &a, + Array, 1> const &b, + Array, 
1> const &c + ) { + + d[0].real() = a[0].real() * b[0].real() + c[0].real(); + d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); + d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); + d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma< + gemm::GemmShape<1, 1, 1>, + 1, + complex, + LayoutA, + float, + LayoutB, + complex, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array, 1> const &a, + Array const &b, + Array, 1> const &c + ) { + + d[0].real() = a[0].real() * b[0] + c[0].real(); + d[0].imag() = a[0].imag() * b[0] + c[0].imag(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma< + gemm::GemmShape<1, 1, 1>, + 1, + float, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array const &a, + Array, 1> const &b, + Array, 1> const &c + ) { + + d[0].real() = a[0] * b[0].real() + c[0].real(); + d[0].imag() = a[0] * b[0].imag() + d[0].imag(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma< + gemm::GemmShape<1, 1, 1>, + 1, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array, 1> const &a, + Array, 1> const &b, + Array, 1> const &c + ) { + + d[0].real() = a[0].real() * b[0].real() + c[0].real(); + d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); + d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); + d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); + } +}; + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma< + gemm::GemmShape<1, 1, 1>, + 1, + complex, + LayoutA, + double, + LayoutB, + complex, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array, 1> const &a, + Array const &b, + Array, 1> const &c + ) { + + d[0].real() = a[0].real() * b[0] + c[0].real(); + d[0].imag() = a[0].imag() * b[0] + c[0].imag(); + } +}; + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma< + 
gemm::GemmShape<1, 1, 1>, + 1, + double, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array const &a, + Array, 1> const &b, + Array, 1> const &c + ) { + + d[0].real() = a[0] * b[0].real() + c[0].real(); + d[0].imag() = a[0] * b[0].imag() + d[0].imag(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = float; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + d[0] = float(a[0]) * float(b[0]) + c[0]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation for Quaternions +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma, 1, Quaternion, LayoutA, Quaternion, LayoutB, Quaternion, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; + using Element = Quaternion; + using ElementC = Element; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + multiply_add op; + d[0] = op(a[0], b[0], c[0]); + } + +}; + +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm60.h b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm60.h new file mode 100644 index 00000000..3e3c71ef --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm60.h @@ -0,0 +1,252 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once + +#include + +#include "cutlass/arch/mma.h" + +#include "cutlass/layout/matrix.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template +struct Mma< + gemm::GemmShape<2,1,1>, + 1, + half_t, + LayoutA, + half_t, + LayoutB, + half_t, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = half_t; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) + + __half2 const & A = reinterpret_cast<__half2 const &>(a); + __half2 B = __half2half2(reinterpret_cast<__half const &>(b)); + __half2 const & C = reinterpret_cast<__half2 const &>(c); + + __half2 D = __hfma2(A, B, C); + + d = reinterpret_cast &>(D); + +#else + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + d[i] = a[i] * b[0] + c[i]; + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template +struct Mma< + gemm::GemmShape<1,2,1>, + 1, + half_t, + LayoutA, + half_t, + LayoutB, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 2, 1>; + using Operator = OpMultiplyAdd; + using ElementC = half_t; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) + + __half2 const & A = __half2half2(reinterpret_cast<__half const &>(a)); + __half2 B = reinterpret_cast<__half2 const &>(b); + __half2 const & C = reinterpret_cast<__half2 const &>(c); + + __half2 D = __hfma2(A, B, C); + + d = reinterpret_cast &>(D); + +#else + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + d[i] = a[0] * b[i] + c[i]; + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template <> +struct Mma < + gemm::GemmShape<2, 2, 1>, + 1, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 2, 1>; + using Operator = OpMultiplyAdd; + using ElementC = half_t; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) + + __half2 const & A = reinterpret_cast<__half2 const &>(a); + __half2 Blo = __low2half2(reinterpret_cast<__half2 const &>(b)); + __half2 Bhi = __high2half2(reinterpret_cast<__half2 const 
&>(b)); + + __half2 const *C = reinterpret_cast<__half2 const *>(&c); + + __half2 Dlo = __hfma2(A, Blo, C[0]); + __half2 Dhi = __hfma2(A, Bhi, C[1]); + + Array * D = reinterpret_cast *>(&d); + + D[0] = reinterpret_cast const &>(Dlo); + D[1] = reinterpret_cast const &>(Dhi); + +#else + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < 2; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j]; + } + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template <> +struct Mma< + gemm::GemmShape<2, 2, 1>, + 1, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 2, 1>; + using Operator = OpMultiplyAdd; + using ElementC = half_t; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) + + __half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a)); + __half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a)); + __half2 const & B = reinterpret_cast<__half2 const &>(b); + + __half2 const *C = reinterpret_cast<__half2 const *>(&c); + + __half2 Dlo = __hfma2(Alo, B, C[0]); + __half2 Dhi = __hfma2(Ahi, B, C[0]); + + Array * D = reinterpret_cast *>(&d); + + D[0] = reinterpret_cast &>(Dlo); + D[1] = reinterpret_cast &>(Dhi); +#else + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < 2; ++j) { + d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j]; + } + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm61.h b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm61.h new file mode 100644 index 00000000..82a5aa72 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm61.h @@ -0,0 +1,142 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template +struct Mma< + gemm::GemmShape<1,1,4>, + 1, + int8_t, + LayoutA, + int8_t, + LayoutB, + int, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 4>; + using Operator = OpMultiplyAdd; + using ElementC = int; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) + + unsigned const &A = reinterpret_cast(a); + unsigned const &B = reinterpret_cast(b); + + asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" + : "=r"(d[0]) + : "r"(A), "r"(B), "r"(c[0])); + +#else + + d[0] = c[0]; + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < 4; ++k) { + d[0] += a[k] * b[k]; + } + +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template +struct Mma< + gemm::GemmShape<1, 1, 2>, + 1, + int16_t, + layout::RowMajor, + int16_t, + layout::ColumnMajor, + int, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 2>; + using Operator = OpMultiplyAdd; + using ElementC = int; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) + + unsigned const &A = reinterpret_cast(a); + unsigned const &B = reinterpret_cast(b); + + asm volatile("dp2a.s32.s32 %0, %1, %2, %3;" + : "=r"(d[0]) + : "r"(A), "r"(B), "r"(c[0])); +#else + d[0] = c[0]; + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < 2; ++k) { + d[0] += a[k] * b[k]; + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm70.h b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm70.h new file mode 100644 index 00000000..2785a3ff --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm70.h @@ -0,0 +1,665 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) +#define CUTLASS_ARCH_MMA_SM70_SUPPORTED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) + +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 &&__CUDACC_VER_MINOR__ >= 1)) +#define CUTLASS_ARCH_MMA_SM70_ENABLED +#endif + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Matrix multiply accumulate 884 - FP16 accumulation +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<8,8,4>, + 8, + half_t, + layout::ColumnMajor, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::ColumnMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + unsigned const *C = reinterpret_cast(&c); + unsigned *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) + ); + +#else + 
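+    // Fallback when CUTLASS_ARCH_MMA_SM70_ENABLED is not defined: there is no scalar
+    // emulation for this specialization, so it asserts on the host and executes brkpt
+    // on the device.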
assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::ColumnMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::RowMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + unsigned const *C = reinterpret_cast(&c); + unsigned *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + unsigned const *C = reinterpret_cast(&c); + unsigned *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::RowMajor, + half_t, + layout::RowMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::RowMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if 
defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + unsigned const *C = reinterpret_cast(&c); + unsigned *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Matrix multiply accumulate 884 - FP32 accumulation +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::ColumnMajor, + half_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::ColumnMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + /// Multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(D[0]), + "=f"(D[1]), + "=f"(D[2]), + "=f"(D[3]), + "=f"(D[4]), + "=f"(D[5]), + "=f"(D[6]), + "=f"(D[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::ColumnMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::RowMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + /// Multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(D[0]), + "=f"(D[1]), + "=f"(D[2]), + 
"=f"(D[3]), + "=f"(D[4]), + "=f"(D[5]), + "=f"(D[6]), + "=f"(D[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + /// Multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(D[0]), + "=f"(D[1]), + "=f"(D[2]), + "=f"(D[3]), + "=f"(D[4]), + "=f"(D[5]), + "=f"(D[6]), + "=f"(D[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::RowMajor, + half_t, + layout::RowMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::RowMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + /// Multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(D[0]), + "=f"(D[1]), + "=f"(D[2]), + "=f"(D[3]), + "=f"(D[4]), + "=f"(D[5]), + "=f"(D[6]), + "=f"(D[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation specialized for the entire warp +template < + 
typename LayoutA, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename Operator +> +struct Mma< + gemm::GemmShape<16, 16, 4>, + 32, + half_t, + LayoutA, + half_t, + LayoutB, + ElementC, + LayoutC, + Operator +> : + public Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + LayoutA, + half_t, + LayoutB, + ElementC, + LayoutC, + Operator> { + + using Shape = gemm::GemmShape<16, 16, 4>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm75.h b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm75.h new file mode 100644 index 00000000..33c35235 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm75.h @@ -0,0 +1,1285 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply for SM75 +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/arch/wmma.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +// CUDA Toolkit includes for nvcuda::wmma needed for binarized matrix multiply. 
+#include +#include "cutlass/wmma_array.h" +#endif + +// CUTLASS includes +#include "cutlass/arch/mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) + +#define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1 + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) +#define CUTLASS_ARCH_MMA_SM75_ENABLED +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1688 - FP16 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation - F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<16, 8, 8>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 8>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; + + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + unsigned const *C = reinterpret_cast(&c); + unsigned *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1688 - FP32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 8>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 8>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + 
"r"(A[0]), "r"(A[1]), + "r"(B[0]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Integer matrix multiply .8816 (8b) +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 16>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 16>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 16>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + 
using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k16.row.col.s8.u8 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 16>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Integer matrix multiply (8b) with SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 16>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + 
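+      // Without SM75 mma.sync support the operands are marked unused and the call below
+      // reports "not implemented" instead of silently producing incorrect results.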
CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 16>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 16>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 16>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = 
reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Integer matrix multiply (4b) +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 32>, + 32, + int4b_t, + layout::RowMajor, + int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 32>; + + using ElementA = int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 32>, + 32, + uint4b_t, + layout::RowMajor, + int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 32>; + + using ElementA = uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 32>, + 32, + int4b_t, + layout::RowMajor, + uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 32>; + + using ElementA = 
int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 32>, + 32, + uint4b_t, + layout::RowMajor, + uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 32>; + + using ElementA = uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Integer matrix multiply (4b) - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 32>, + 32, + int4b_t, + layout::RowMajor, + int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 32>; + + using ElementA = int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm 
volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 32>, + 32, + uint4b_t, + layout::RowMajor, + int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 32>; + + using ElementA = uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 32>, + 32, + int4b_t, + layout::RowMajor, + uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 32>; + + using ElementA = int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 32>, + 32, + uint4b_t, + layout::RowMajor, + uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 32>; + + using ElementA = uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = 
OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// b1 ^ b1 + s32 => s32 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template <> +struct Mma< + gemm::GemmShape<8,8,128>, + 32, + uint1b_t, + layout::RowMajor, + uint1b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpXorPopc> { + + using Shape = gemm::GemmShape<8,8,128>; + + using ElementA = uint1b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint1b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpXorPopc; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + using WmmaFragmentA = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_a, + Shape::kM, + Shape::kN, + Shape::kK, + nvcuda::wmma::experimental::precision::b1, + nvcuda::wmma::row_major>; + + using WmmaFragmentB = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_b, + Shape::kM, + Shape::kN, + Shape::kK, + nvcuda::wmma::experimental::precision::b1, + nvcuda::wmma::col_major>; + + using WmmaFragmentC = nvcuda::wmma::fragment< + nvcuda::wmma::accumulator, + Shape::kM, + Shape::kN, + Shape::kK, + int>; + + WmmaFragmentA const & A = reinterpret_cast(a); + WmmaFragmentB const & B = reinterpret_cast(b); + + WmmaFragmentC const & C = reinterpret_cast(c); + WmmaFragmentC & D = reinterpret_cast(d); + + nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, + nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); // WMMA must be supported to issue binary matrix multiply-accumulate instructions. + +#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm80.h b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm80.h new file mode 100644 index 00000000..3b49aad9 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm80.h @@ -0,0 +1,2265 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/cutlass.h" +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) + +#define CUTLASS_ARCH_MMA_SM80_SUPPORTED 1 + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +#define CUTLASS_ARCH_MMA_SM80_ENABLED + +#if (__CUDA_ARCH__ <= 900) +#define CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED +#endif +#if (__CUDA_ARCH__ <= 890) +#define CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED +#endif + +#endif + +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1688 - Float BF16, FP32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation - F32 = bf16 * bf16 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 8>, + 32, + bfloat16_t, + layout::RowMajor, + bfloat16_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 8>; + + using ElementA = bfloat16_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = bfloat16_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if 
defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast<uint32_t const *>(&a); + uint32_t const *B = reinterpret_cast<uint32_t const *>(&b); + float const *C = reinterpret_cast<float const *>(&c); + float *D = reinterpret_cast<float *>(&d); + + asm( + "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), + "r"(B[0]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1684 - Float TF32 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 4>, + 32, + tfloat32_t, + layout::RowMajor, + tfloat32_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 4>; + + using ElementA = tfloat32_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array<tfloat32_t, 2>; + + using ElementB = tfloat32_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array<tfloat32_t, 1>; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array<float, 4>; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast<uint32_t const *>(&a); + uint32_t const *B = reinterpret_cast<uint32_t const *>(&b); + float const *C = reinterpret_cast<float const *>(&c); + float *D = reinterpret_cast<float *>(&d); + + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), + "r"(B[0]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1688 - Float TF32 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 +template <> +struct Mma<gemm::GemmShape<16, 8, 8>, 32, tfloat32_t, layout::RowMajor, + tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, + OpMultiplyAdd> { + using Shape = gemm::GemmShape<16, 8, 8>; + + using ElementA = tfloat32_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array<tfloat32_t, 4>; + + using ElementB = tfloat32_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array<tfloat32_t, 2>; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array<float, 4>; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast<uint32_t const *>(&a); + uint32_t const *B = reinterpret_cast<uint32_t const *>(&b); + float const *C = reinterpret_cast<float const *>(&c); + float *D = reinterpret_cast<float *>(&d); + + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : 
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16816 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<16, 8, 16>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 16>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 16>, + 32, + bfloat16_t, + layout::RowMajor, + bfloat16_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 16>; + + using ElementA = bfloat16_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = bfloat16_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< 
+ gemm::GemmShape<16, 8, 16>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 16>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 884 - F64 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<8,8,4>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8,8,4>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm80; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + double const & A = reinterpret_cast(a); + double const & B = reinterpret_cast(b); + + double const *C = reinterpret_cast(&c); + double *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=d"(D[0]), "=d"(D[1]) + : "d"(A), "d"(B), "d"(C[0]), "d"(C[1])); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16816 - S8 input, S32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using 
LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + 
uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + 
uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16832 - S8 input, S32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = 
gemm::GemmShape<16,8,32>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), 
"=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const * A = reinterpret_cast(&a); + uint32_t const * B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + 
using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,32>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,32>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), 
"r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16864 - S4 input, S32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::int4b_t; + 
using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16864 - S4 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void 
operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const * A = reinterpret_cast(&a); + uint32_t const * B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + 
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 64>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16, 8, 64>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 168256 - B1 input, S32 accumulation - AND,POPC +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = B1 & B1 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,256>, + 32, + cutlass::uint1b_t, + layout::RowMajor, + cutlass::uint1b_t, + layout::ColumnMajor, + int32_t, + layout::RowMajor, + OpAndPopc> { + + using Shape = gemm::GemmShape<16,8,256>; + + using ElementA = cutlass::uint1b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint1b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int32_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpAndPopc; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = B1 & B1 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,256>, + 32, + cutlass::uint1b_t, + layout::RowMajor, + cutlass::uint1b_t, 
+ layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,256>; + + using ElementA = cutlass::uint1b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint1b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int32_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 168256 - B1 input, S32 accumulation - XOR,POPC +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = B1 & B1 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,256>, + 32, + cutlass::uint1b_t, + layout::RowMajor, + cutlass::uint1b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpXorPopc> { + + using Shape = gemm::GemmShape<16,8,256>; + + using ElementA = cutlass::uint1b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint1b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpXorPopc; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); + +#endif // defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED) + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm89.h b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm89.h new file mode 100644 index 00000000..fe4b7eb7 --- /dev/null +++ 
b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm89.h @@ -0,0 +1,367 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Matrix multiply-accumulate specialzied for SM89 +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/cutlass.h" +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) + +# define CUTLASS_ARCH_MMA_SM89_SUPPORTED 1 +#endif + +#if defined(CUTLASS_ARCH_MMA_SM89_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 890) +# define CUTLASS_ARCH_MMA_SM89_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Whether the Mma uses as SM89 staged accumulation policy +template +static constexpr bool is_sm89_staged_policy_v = + ( + // ElementA must be FP8 + platform::is_same::value || + platform::is_same::value + ) && + ( + // ElementB must be FP8 + platform::is_same::value || + platform::is_same::value + ) && + ( + // The instruction shape must be 16x8x32 + Operator::ArchMmaOperator::Shape::kM == 16 && + Operator::ArchMmaOperator::Shape::kN == 8 && + Operator::ArchMmaOperator::Shape::kK == 32 + ) && + ( + // The operator must be OpMultiplyAdd (default) + platform::is_same::value + ); +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16832 - Float {E4M3, E5M2}, FP32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation - F32 = fe4m3 * fe4m3 + F32 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F32 = fe4m3 * fe5m2 + F32 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e4m3_t, + 
layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F32 = fe5m2 * fe4m3 + F32 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F32 = fe5m2 * fe5m2 + F32 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + 
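+  // Assumed usage sketch (names below are illustrative, not part of this header): a warp-wide
+  // call of this FP8 MMA. Per thread, A carries 16 fp8 values in four 32-bit registers, B carries
+  // 8 fp8 values in two registers, and C/D are four fp32 accumulators, matching the PTX operand
+  // lists used below.
+  //
+  //   using Mma16832 = cutlass::arch::Mma<
+  //       cutlass::gemm::GemmShape<16, 8, 32>, 32,
+  //       cutlass::float_e5m2_t, cutlass::layout::RowMajor,
+  //       cutlass::float_e5m2_t, cutlass::layout::ColumnMajor,
+  //       float, cutlass::layout::RowMajor,
+  //       cutlass::arch::OpMultiplyAdd>;
+  //
+  //   Mma16832::FragmentA frag_a;   // filled from shared memory, e.g. via ldmatrix
+  //   Mma16832::FragmentB frag_b;
+  //   Mma16832::FragmentC accum;
+  //   accum.clear();
+  //   Mma16832{}(accum, frag_a, frag_b, accum);  // all 32 threads of the warp must participate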
using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +} // namespace arch +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm90.h b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm90.h new file mode 100644 index 00000000..b2e76a90 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sm90.h @@ -0,0 +1,267 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Matrix multiply +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) + #define CUTLASS_ARCH_MMA_SM90_F64_MMA_SUPPORTED + #if (!defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)) + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + #define CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED + #endif + #endif +#endif + +#if (__CUDACC_VER_MAJOR__ >= 12) + #define CUTLASS_ARCH_MMA_SM90_SUPPORTED + #if (!defined(CUTLASS_ARCH_MMA_SM90_ENABLED)) + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + #define CUTLASS_ARCH_MMA_SM90_ENABLED + #endif + #endif +#endif + +#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3))) + #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +/// Matrix Multiply-Add 16x8x4 fp64 +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<16,8,4>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,4>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm90; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) + + double const *A = reinterpret_cast(&a); + double const *B = reinterpret_cast(&b); + + double const *C = reinterpret_cast(&c); + double *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64.rn {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) + : "d"(A[0]), "d"(A[1]), + "d"(B[0]), + "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); + +#else + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Matrix Multiply-Add 16x8x8 fp64 +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<16,8,8>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,8>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = 
arch::Sm90; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) + + double const *A = reinterpret_cast<double const *>(&a); + double const *B = reinterpret_cast<double const *>(&b); + + double const *C = reinterpret_cast<double const *>(&c); + double *D = reinterpret_cast<double *>(&d); + + asm volatile("mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) + : "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]), + "d"(B[0]), "d"(B[1]), + "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Matrix Multiply-Add 16x8x16 fp64 +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array<double, 8>; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array<double, 4>; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array<double, 4>; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm90; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) + + double const *A = reinterpret_cast<double const *>(&a); + double const *B = reinterpret_cast<double const *>(&b); + + double const *C = reinterpret_cast<double const *>(&c); + double *D = reinterpret_cast<double *>(&d); + + asm volatile("mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7, %8, %9, %10, %11}, {%12, %13, %14, %15}, {%16, %17, %18, %19};\n" + : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) + : "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]), "d"(A[4]), "d"(A[5]), "d"(A[6]), "d"(A[7]), + "d"(B[0]), "d"(B[1]), "d"(B[2]), "d"(B[3]), + "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); + +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/mma_sparse_sm80.h b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sparse_sm80.h new file mode 100644 index 00000000..255df67f --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sparse_sm80.h @@ -0,0 +1,1685 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Sparse matrix multiply accumulate for SM80 +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1)) + +#define CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED 1 + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +#define CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED +#endif +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16832 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct SparseMma< + gemm::GemmShape<16, 8, 32>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread +> { + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 2; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c, uint32_t const &E, int const id2) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = 
reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " + "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); + } + else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " + "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); + } + else { + assert(0); + } +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct SparseMma< + gemm::GemmShape<16, 8, 32>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread + > { + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 2; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c, uint32_t const &E, int const id2) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), + "r"(E)); + } + else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), + "r"(E)); + } + else { + assert(0); + } + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16832 - Float BF16, FP32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 +template <> +struct SparseMma, 32, bfloat16_t, layout::RowMajor, + bfloat16_t, layout::ColumnMajor, float, layout::RowMajor, + OpMultiplyAdd, SPFormatType::Thread> { + using Shape = gemm::GemmShape<16, 8, 32>; 
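+  // Structured-sparsity summary (an assumed gloss of the mma.sp PTX semantics below, not text
+  // from upstream CUTLASS): operand A is stored 2:4 compressed, so its fragment only covers
+  // Shape::kK / kSparse = 16 of the 32 K-positions per row; the 32-bit metadata word E supplies
+  // kMetaSizeInBits = 2 bits per kept element, naming its position within each group of four,
+  // and id2 (< kMaxID2 = 2) chooses the sparsity selector, i.e. the 0x0 / 0x1 immediate passed
+  // to the instruction.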
+ + using ElementA = bfloat16_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = bfloat16_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 2; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c, uint32_t const &E, int const id2) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else { + assert(0); + } + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16816 - Float TF32 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 +template <> +struct SparseMma, 32, tfloat32_t, layout::RowMajor, + tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, + OpMultiplyAdd, SPFormatType::Thread> { + using Shape = gemm::GemmShape<16, 8, 16>; + + using ElementA = tfloat32_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = tfloat32_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 4; + + static int const kMaxID2 = 2; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c, uint32_t const &E, int const id2) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else if (id2 == 1) { + asm 
volatile( + "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else { + assert(0); + } + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t 
const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + 
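+    // kMaxID2 == 1 for the 8-bit sparse MMAs, so the only valid sparsity selector is 0 and the
+    // PTX immediate is fixed at 0x0; any other id2 value falls through to assert(0) below.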
if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + 
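+    // Same operand layout as the non-saturating s8 * u8 form above; per the PTX .satfinite
+    // semantics, the s32 accumulator result is clamped at the type's min/max instead of wrapping
+    // on overflow.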
uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = 
reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const 
&a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, 
+ FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = 
arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using 
ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/mma_sparse_sm89.h b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sparse_sm89.h new file mode 100644 index 00000000..c092df76 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/mma_sparse_sm89.h @@ -0,0 +1,409 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Sparse matrix multiply accumulate for SM89 +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) + +# define CUTLASS_ARCH_SPARSE_MMA_SM89_SUPPORTED 1 +#endif + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 890) +# define CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = fe4m3 * fe4m3 + F32 +template +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_, + SPFormatType::Thread> { + + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e4m3.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } + else { + assert(0); + } +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = fe4m3 * fe5m2 + F32 +template +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_, + SPFormatType::Thread> { + + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = 
layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e5m2.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } + else { + assert(0); + } +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = fe5m2 * fe4m3 + F32 +template +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_, + SPFormatType::Thread> { + + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e4m3.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } + else { + assert(0); + } +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = fe5m2 * fe5m2 + F32 +template +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + 
cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_, + SPFormatType::Thread> { + + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e5m2.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } + else { + assert(0); + } +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/reg_reconfig.h b/server/punica_kernels/include/cutlass/cutlass/arch/reg_reconfig.h new file mode 100644 index 00000000..f7b12a70 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/reg_reconfig.h @@ -0,0 +1,67 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief PTX for CTA Reconfiguration +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) + #if (defined(__CUDA_ARCH_FEAT_SM90_ALL)) + #define CUDA_CTA_RECONFIG_ACTIVATED 1 + #endif +#else + #define CUDA_CTA_RECONFIG_ACTIVATED 0 +#endif + +namespace cutlass { +namespace arch { + +template +CUTLASS_DEVICE +void warpgroup_reg_alloc(){ +#if CUDA_CTA_RECONFIG_ACTIVATED + asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); +#endif +} +template +CUTLASS_DEVICE +void warpgroup_reg_dealloc(){ +#if CUDA_CTA_RECONFIG_ACTIVATED + asm volatile( "setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); +#endif +} + +} // namespace arch +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/simd.h b/server/punica_kernels/include/cutlass/cutlass/arch/simd.h new file mode 100644 index 00000000..3104746e --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/simd.h @@ -0,0 +1,125 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Templates exposing SIMD operators +*/ + +#pragma once + +#include "../array.h" +#include "../numeric_types.h" + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Element-wise operators +// + +CUTLASS_HOST_DEVICE +template +Array operator*(Array const &a, Array const &b) { + Array d; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + d[i] = a[i] * b[i]; + } + return d; +} + +CUTLASS_HOST_DEVICE +template +Array operator+(Array const &a, Array const &b) { + Array d; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + d[i] = a[i] + b[i]; + } + return d; +} + +CUTLASS_HOST_DEVICE +template +Array operator-(Array const &a, Array const &b) { + Array d; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + d[i] = a[i] - b[i]; + } + return d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Multiply-accumulate operators +// + +CUTLASS_HOST_DEVICE +template +Array mac(Array const &a, Array const &b, Array const &c) { + Array d; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + d[i] = a[i] * b[i] + c[i]; + } + return d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Dot product operator +// + +CUTLASS_HOST_DEVICE +template +Accumulator dot(Array const &a, Array const &b, Accumulator accum) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + accum += a[i] * b[i]; + } + return accum; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "simd_sm60.h" +#include "simd_sm61.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/simd_sm60.h b/server/punica_kernels/include/cutlass/cutlass/arch/simd_sm60.h new file mode 100644 index 00000000..6e1ef204 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/simd_sm60.h @@ -0,0 +1,104 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates exposing SIMD operators for SM60 +*/ + +#pragma once + +#include "simd.h" + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Element-wise operators - specialized for half_t x 2 +// + +CUTLASS_HOST_DEVICE +template <> +Array operator*(Array const &a, Array const &b) { + Array d; + + return d; +} + +CUTLASS_HOST_DEVICE +template <> +Array operator+(AArray const &a, Array const &b) { + Array d; + + return d; +} + +CUTLASS_HOST_DEVICE +template <> +Array operator-(Array const &a, Array const &b) { + Array d; + + return d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Multiply-accumulate operators - specialized for half_t x 2 +CUTLASS_HOST_DEVICE +template <> +Array mac(Array const &a, Array const &b, Array const &c) { + Array d; + + return d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dot product operator - specialized for half_t <- (half_t * half_t) x 2 + half_t +CUTLASS_HOST_DEVICE +template <> +half_t dot(Array const &a, Array const &b, half_t accum) { + + return accum; +} + +/// Dot product operator - specialized for float <- (half_t * half_t) x 2 + float +CUTLASS_HOST_DEVICE +template <> +float dot(Array const &a, Array const &b, float accum) { + + return accum; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/simd_sm61.h b/server/punica_kernels/include/cutlass/cutlass/arch/simd_sm61.h new file mode 100644 index 00000000..b783c943 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/simd_sm61.h @@ -0,0 +1,147 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates exposing SIMD operators for SM61 +*/ + +#pragma once + +#include "simd.h" + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dot product operator - specialized for int32_t <- (int8_t * int8_t) x 4 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (uint8_t * int8_t) x 4 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (int8_t * uint8_t) x 4 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (uint8_t * uint8_t) x 4 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t 
dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/wmma.h b/server/punica_kernels/include/cutlass/cutlass/arch/wmma.h new file mode 100644 index 00000000..80cb8939 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/wmma.h @@ -0,0 +1,223 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates exposing architecture support for warp matrix multiply-add (WMMA) operations +*/ + +#pragma once + +// CUTLASS WMMA does not support clang at present. 
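+//
+// Usage sketch (illustrative only): once the guards below define
+// CUTLASS_ARCH_WMMA_ENABLED, device code can instantiate the arch::Wmma wrapper
+// declared later in this header roughly as follows:
+//
+//   using WmmaF16 = cutlass::arch::Wmma<
+//       cutlass::gemm::GemmShape<16, 16, 16>,          // native WMMA tile
+//       cutlass::half_t, cutlass::layout::RowMajor,    // A
+//       cutlass::half_t, cutlass::layout::ColumnMajor, // B
+//       float,           cutlass::layout::RowMajor,    // C
+//       cutlass::arch::OpMultiplyAdd>;
+//
+//   // WmmaF16::FragmentA/B/C are nvcuda::wmma fragments; WmmaF16{}(D, A, B, C)
+//   // forwards to nvcuda::wmma::mma_sync (SM70 specialization, see wmma_sm70.h).
+//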
+#if !(defined(__clang__) && defined(__CUDA__)) + +#if (__CUDACC_VER_MAJOR__ >= 9) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700)) +#define CUTLASS_ARCH_WMMA_ENABLED +#define CUTLASS_ARCH_WMMA_SM70_ENABLED +#endif +#endif + +#if (__CUDACC_VER_MAJOR__ >= 10) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 720)) +#define CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED +#define CUTLASS_ARCH_WMMA_SM72_ENABLED +#endif +#endif + +#if (__CUDACC_VER_MAJOR__ >= 10) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750)) +#define CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED +#define CUTLASS_ARCH_WMMA_SM75_ENABLED +#endif +#endif + +#endif //!(defined(__clang__) && defined(__CUDA__)) + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + +#include +#include "cutlass/arch/mma.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/gemm.h" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////// +/// Statically maps cutlass data types => nvcuda::wmma data types +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct CutlassToWmmaDataType{ + using Type = Type_; +}; + +/// Statically maps cutlass::half_t => __half +template<> +struct CutlassToWmmaDataType { + using Type = __half; +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) +template<> +struct CutlassToWmmaDataType { + using Type = __nv_bfloat16; +}; +#endif + +/// Statically maps int8_t => char +template<> +struct CutlassToWmmaDataType { + using Type = signed char; +}; + +/// Statically maps uint8_t => char +template<> +struct CutlassToWmmaDataType { + using Type = unsigned char; +}; + +/// Statically maps int32_t => int +template<> +struct CutlassToWmmaDataType { + using Type = int; +}; + +#if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED) +/// Statically maps cutlass::int4b_t => experimental::precision::s4 +template<> +struct CutlassToWmmaDataType { + using Type = nvcuda::wmma::experimental::precision::s4; +}; + +/// Statically maps cutlass::uint4b_t => experimental::precision::s4 +template<> +struct CutlassToWmmaDataType { + using Type = nvcuda::wmma::experimental::precision::u4; +}; + +/// Statically maps cutlass::uint1b_t => experimental::precision::b1 +template<> +struct CutlassToWmmaDataType { + using Type = nvcuda::wmma::experimental::precision::b1; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////// +/// Statically maps cutlass::layout => nvcuda::wmma layout tags +//////////////////////////////////////////////////////////////////////////////////////////////// +template +struct CutlassToWmmaLayout { +}; + +/// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags +template <> +struct CutlassToWmmaLayout { + using Layout = nvcuda::wmma::row_major; + static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_row_major; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////// +/// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags +//////////////////////////////////////////////////////////////////////////////////////////////// +template <> +struct CutlassToWmmaLayout { + using Layout = nvcuda::wmma::col_major; + static 
nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_col_major; +}; +//////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////// +/// Statically maps nvcuda::wmma data types => cutlass data types +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct WmmaToCutlassDataType{ + using Type = Type_; +}; + +/// Statically maps __half => cutlass::half_t +template<> +struct WmmaToCutlassDataType<__half> { + using Type = cutlass::half_t; +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) +template<> +struct WmmaToCutlassDataType<__nv_bfloat16> { + using Type = cutlass::bfloat16_t; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// WMMA template structure defines nvcuda::wmma::fragments and static assertion chaeks +// for a specific template paramterized data type (Element[A|B|C]), layout (Layout[A|B|C]), +// and native wmma size (Shape) +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Shape_, ///< Size of the matrix product (concept: GemmShape) + typename ElementA_, ///< Data type of A elements + typename LayoutA_, ///< Layout of A matrix (concept: MatrixLayout) + typename ElementB_, ///< Data type of B elements + typename LayoutB_, ///< Layout of B matrix (concept: MatrixLayout) + typename ElementC_, ///< Element type of C matrix + typename LayoutC_, /// Layout of C matrix (concept: MatrixLayout) + typename Operator_ = cutlass::arch::OpMultiplyAdd ///< Inner product operator (multiply-add, xor.popc) +> +struct Wmma; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Specializations for each compute capability +// +#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED +#include "cutlass/arch/wmma_sm70.h" +#endif + +#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED +#include "cutlass/arch/wmma_sm72.h" +#endif + +#ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED +#include "cutlass/arch/wmma_sm75.h" +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif //CUTLASS_ARCH_WMMA_ENABLED diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/wmma_sm70.h b/server/punica_kernels/include/cutlass/cutlass/arch/wmma_sm70.h new file mode 100644 index 00000000..74605a36 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/wmma_sm70.h @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif +#include "cutlass/layout/matrix.h" + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace arch { + + +//////////////////////////////////////////////////////////////////////////////// +// +// WMMA template structure defines nvcuda::wmma::fragments and static assert for +// wmma native instruction sizes supported for half +// +//////////////////////////////////////////////////////////////////////////////// +template < +typename Shape_, +typename LayoutA_, +typename LayoutB_, +typename ElementC_, +typename LayoutC_> +struct Wmma< + Shape_, ///< Size of the matrix product (concept: GemmShape) + cutlass::half_t, ///< ElementA + LayoutA_, ///< LayoutA + cutlass::half_t, ///< ElementB + LayoutB_, ///< LayoutB + ElementC_, ///< ElementC + LayoutC_, ///< LayoutC + cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) +> { + +#if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED) + using Shape = Shape_; + using ElementA = cutlass::half_t; + using LayoutA = LayoutA_; + using ElementB = cutlass::half_t; + using LayoutB = LayoutB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm70; + + // check supported wmma shape for the given multiplicand data types + static_assert( + platform::is_same, Shape>::value || + platform::is_same, Shape>::value || + platform::is_same, Shape>::value, + "Supported list of wmma operator shape for f16 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); + + // check supported wmma output data type for the given multiplicand data types + static_assert( + platform::is_same::value || platform::is_same::value, + "Supported of wmma output data type for f16 multiplicands are: f16 and f32"); + + // Wmma Fragment + using FragmentA = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_a, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentB = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_b, + Shape::kM, + Shape::kN, + Shape::kK, + typename 
CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentC = nvcuda::wmma::fragment< + nvcuda::wmma::accumulator, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type>; + + /// Performs a nvcuda::wmma matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB const &B, + FragmentC const &C) const { + + nvcuda::wmma::mma_sync(D, A, B, C); + } +#else + static_assert(false, "wmma.mma.sync for floating point multiplicands is avialable only for SM70 and beyond"); +#endif + +}; + +} // namespace arch +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/wmma_sm72.h b/server/punica_kernels/include/cutlass/cutlass/arch/wmma_sm72.h new file mode 100644 index 00000000..a2a7dc27 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/wmma_sm72.h @@ -0,0 +1,210 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Matrix multiply +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif +#include "cutlass/layout/matrix.h" + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +// +// WMMA template structure defines nvcuda::wmma::fragments and static assert for +// wmma native instruction sizes supported for int8_t +// +//////////////////////////////////////////////////////////////////////////////// +template < +typename Shape_, +typename LayoutA_, +typename LayoutB_, +typename LayoutC_> +struct Wmma< + Shape_, ///< Size of the matrix product (concept: GemmShape) + int8_t, ///< ElementA + LayoutA_, ///< LayoutA + int8_t, ///< ElementB + LayoutB_, ///< LayoutB + int32_t, ///< ElementC + LayoutC_, ///< LayoutC + cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) +> { +#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED) + using Shape = Shape_; + using ElementA = int8_t; + using LayoutA = LayoutA_; + using ElementB = int8_t; + using LayoutB = LayoutB_; + using ElementC = int32_t; + using LayoutC = LayoutC_; + using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm72; + + // check supported wmma shape for the given multiplicand data types + static_assert( + platform::is_same, Shape>::value || + platform::is_same, Shape>::value || + platform::is_same, Shape>::value, + "Supported list of wmma operator shape for s8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); + + + // Wmma Fragment + using FragmentA = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_a, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentB = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_b, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentC = nvcuda::wmma::fragment< + nvcuda::wmma::accumulator, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type>; + + /// Performs a nvcuda::wmma matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB const &B, + FragmentC const &C) const { + + nvcuda::wmma::mma_sync(D, A, B, C); + } + +#else + static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond"); +#endif + +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// WMMA template structure defines nvcuda::wmma::fragments and static assert for +// wmma native instruction sizes supported for uint8_t +// +//////////////////////////////////////////////////////////////////////////////// +template < +typename Shape_, +typename LayoutA_, +typename LayoutB_, +typename LayoutC_> +struct Wmma< + Shape_, ///< Size of the matrix product (concept: GemmShape) + uint8_t, ///< ElementA + LayoutA_, ///< LayoutA + uint8_t, ///< ElementB + LayoutB_, ///< LayoutB + int32_t, ///< ElementC + LayoutC_, ///< LayoutC + cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) +> { +#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED) + using Shape = Shape_; + using ElementA = uint8_t; + using LayoutA = LayoutA_; + using ElementB = uint8_t; + using LayoutB = LayoutB_; + using ElementC = int32_t; + using LayoutC = LayoutC_; + using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm72; + + // check 
supported wmma shape for the given multiplicand data types + static_assert( + platform::is_same, Shape>::value || + platform::is_same, Shape>::value || + platform::is_same, Shape>::value, + "Supported list of wmma operator shape for u8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); + + // Wmma Fragment + using FragmentA = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_a, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentB = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_b, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentC = nvcuda::wmma::fragment< + nvcuda::wmma::accumulator, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type>; + + /// Performs a nvcuda::wmma matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB const &B, + FragmentC const &C) const { + + nvcuda::wmma::mma_sync(D, A, B, C); + } + +#else + static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond"); +#endif + +}; + +} // namespace arch +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/arch/wmma_sm75.h b/server/punica_kernels/include/cutlass/cutlass/arch/wmma_sm75.h new file mode 100644 index 00000000..80db4417 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/arch/wmma_sm75.h @@ -0,0 +1,207 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Matrix multiply +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif +#include "cutlass/layout/matrix.h" + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +// +// WMMA template structure defines nvcuda::wmma::fragments and static assert for +// wmma native instruction sizes supported for cutlass::int4b_t (experimental::s4). +// +//////////////////////////////////////////////////////////////////////////////// +template < +typename Shape_, +typename LayoutA_, +typename LayoutB_, +typename LayoutC_> +struct Wmma< + Shape_, ///< Size of the matrix product (concept: GemmShape) + cutlass::int4b_t, ///< ElementA + LayoutA_, ///< LayoutA + cutlass::int4b_t, ///< ElementB + LayoutB_, ///< LayoutB + int32_t, ///< ElementC + LayoutC_, ///< LayoutC + cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) +> { +#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) + using Shape = Shape_; + using ElementA = cutlass::int4b_t; + using LayoutA = LayoutA_; + using ElementB = cutlass::int4b_t; + using LayoutB = LayoutB_; + using ElementC = int32_t; + using LayoutC = LayoutC_; + using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm75; + + // check supported wmma shape for the given multiplicand data types + static_assert( + platform::is_same, Shape>::value, + "Supported list of wmma operator shape for s8 multiplicands is: 8x8x32"); + + + // Wmma Fragment + using FragmentA = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_a, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentB = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_b, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentC = nvcuda::wmma::fragment< + nvcuda::wmma::accumulator, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type>; + + /// Performs a nvcuda::wmma matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB const &B, + FragmentC const &C) const { + nvcuda::wmma::mma_sync(D, A, B, C); + + } + +#else + static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond"); +#endif + +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// WMMA template structure defines nvcuda::wmma::fragments and static assert for +// wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1). 
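+// This path lowers to nvcuda::wmma::bmma_sync with the XOR bit operation and POPC
+// accumulation, and only the 8x8x128 shape is accepted (see the static_assert below).
+// For illustration only (a sketch of an instantiation, not code from this header), the
+// trait is resolved as e.g.
+//   cutlass::arch::Wmma<cutlass::gemm::GemmShape<8, 8, 128>,
+//                       cutlass::uint1b_t, cutlass::layout::RowMajor,
+//                       cutlass::uint1b_t, cutlass::layout::ColumnMajor,
+//                       int32_t, cutlass::layout::RowMajor,
+//                       cutlass::arch::OpXorPopc>
+// which exposes FragmentA/B/C typedefs and an operator() wrapping the bmma_sync call.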
+// +//////////////////////////////////////////////////////////////////////////////// +template < +typename Shape_, +typename LayoutA_, +typename LayoutB_, +typename LayoutC_> +struct Wmma< + Shape_, ///< Size of the matrix product (concept: GemmShape) + cutlass::uint1b_t, ///< ElementA + LayoutA_, ///< LayoutA + cutlass::uint1b_t, ///< ElementB + LayoutB_, ///< LayoutB + int32_t, ///< ElementC + LayoutC_, ///< LayoutC + cutlass::arch::OpXorPopc ///< Operator (multiply-add, xor.popc) +> { +#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) + using Shape = Shape_; + using ElementA = cutlass::uint1b_t; + using LayoutA = LayoutA_; + using ElementB = cutlass::uint1b_t; + using LayoutB = LayoutB_; + using ElementC = int32_t; + using LayoutC = LayoutC_; + using Operator = cutlass::arch::OpXorPopc; + using ArchTag = arch::Sm75; + + // check supported wmma shape for the given multiplicand data types + static_assert( + platform::is_same, Shape>::value, + "Supported list of wmma operator shape for b1 multiplicands is: 8x8x128"); + + + // Wmma Fragment + using FragmentA = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_a, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentB = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_b, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentC = nvcuda::wmma::fragment< + nvcuda::wmma::accumulator, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type>; + + /// Performs a nvcuda::wmma matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB const &B, + FragmentC const &C) const { + nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, + nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); + } + +#else + static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond"); +#endif + +}; + +} // namespace arch +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/array.h b/server/punica_kernels/include/cutlass/cutlass/array.h new file mode 100644 index 00000000..a54a8d31 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/array.h @@ -0,0 +1,2635 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types
+           and is safe to use in a union.
+*/
+
+/*
+  Note:  CUTLASS 3x increases the host compiler requirements to C++17. However, certain
+         existing integrations of CUTLASS require C++11 host compilers.
+
+         Until this requirement can be lifted, certain headers with this annotation are required
+         to remain consistent with C++11 syntax.
+
+         C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`.
+*/
+
+#pragma once
+#include "cutlass/cutlass.h"
+#include "cutlass/functional.h"
+#include "cutlass/numeric_types.h"
+namespace cutlass {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Statically sized array for any data type
+template <
+  typename T,
+  int N,
+  bool RegisterSized = sizeof_bits<T>::value >= 32
+>
+class Array;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Defines the size of an Array<> in bits
+template <typename T, int N, bool RegisterSized>
+struct sizeof_bits<Array<T, N, RegisterSized> > {
+  static constexpr int value = sizeof(Array<T, N, RegisterSized>) * 8;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Returns true if the argument is a power of 2
+CUTLASS_HOST_DEVICE
+constexpr bool ispow2(unsigned x) {
+  return x && (!(x & (x - 1)));
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Returns the largest power of two not greater than the argument.
+CUTLASS_HOST_DEVICE
+constexpr unsigned floor_pow_2(unsigned x) {
+  return (x == 0 || ispow2(x)) ?
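+    // x is zero or already a power of two: return it unchanged; otherwise recurse on
+    // x >> 1 and shift the result back left, which discards low-order bits until only
+    // the highest set bit remains.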
x : ((floor_pow_2(x >> 1)) << 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Statically sized array for any data type +template < + typename T, + int N +> +class Array { +public: + + /// Storage type + using Storage = T; + + /// Element type + using Element = T; + + /// Number of storage elements + //static std::size_t const kStorageElements = N; + static size_t const kStorageElements = N; + + /// Number of logical elements + static size_t const kElements = N; + + // + // C++ standard members + // + + typedef T value_type; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + typedef value_type &reference; + typedef value_type const & const_reference; + typedef value_type *pointer; + typedef value_type const * const_pointer; + + // + // Iterators + // + + /// Bidirectional iterator over elements + class iterator { + + /// Pointer to object + T *ptr_; + + public: + + CUTLASS_HOST_DEVICE + iterator(): ptr_(nullptr) { } + + CUTLASS_HOST_DEVICE + iterator(T *_ptr): ptr_(_ptr) { } + + CUTLASS_HOST_DEVICE + iterator &operator++() { + ++ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + iterator &operator--() { + --ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + iterator operator++(int) { + iterator ret(*this); + ++ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + iterator operator--(int) { + iterator ret(*this); + --ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + T &operator*() const { + return *ptr_; + } + + CUTLASS_HOST_DEVICE + bool operator==(iterator const &other) const { + return ptr_ == other.ptr_; + } + + CUTLASS_HOST_DEVICE + bool operator!=(iterator const &other) const { + return ptr_ != other.ptr_; + } + }; + + /// Bidirectional constant iterator over elements + class const_iterator { + + /// Pointer to object + const T *ptr_; + + public: + + CUTLASS_HOST_DEVICE + const_iterator(): ptr_(nullptr) { } + + CUTLASS_HOST_DEVICE + const_iterator(T const *_ptr): ptr_(_ptr) { } + + CUTLASS_HOST_DEVICE + const_iterator &operator++() { + ++ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + const_iterator &operator--() { + --ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + const_iterator operator++(int) { + const_iterator ret(*this); + ++ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + const_iterator operator--(int) { + const_iterator ret(*this); + --ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + T const &operator*() const { + return *ptr_; + } + + CUTLASS_HOST_DEVICE + bool operator==(const_iterator const &other) const { + return ptr_ == other.ptr_; + } + + CUTLASS_HOST_DEVICE + bool operator!=(const_iterator const &other) const { + return ptr_ != other.ptr_; + } + }; + + /// Bidirectional iterator over elements + class reverse_iterator { + + /// Pointer to object + T *ptr_; + + public: + + CUTLASS_HOST_DEVICE + reverse_iterator(): ptr_(nullptr) { } + + CUTLASS_HOST_DEVICE + reverse_iterator(T *_ptr): ptr_(_ptr) { } + + CUTLASS_HOST_DEVICE + reverse_iterator &operator++() { + --ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + reverse_iterator &operator--() { + ++ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + reverse_iterator operator++(int) { + iterator ret(*this); + --ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + reverse_iterator operator--(int) { + iterator ret(*this); + ++ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + T &operator*() const { + return *(ptr_ - 1); + } + + CUTLASS_HOST_DEVICE + bool operator==(reverse_iterator const &other) const { + return ptr_ == other.ptr_; + } + 
+ CUTLASS_HOST_DEVICE + bool operator!=(reverse_iterator const &other) const { + return ptr_ != other.ptr_; + } + }; + + /// Bidirectional constant iterator over elements + class const_reverse_iterator { + + /// Pointer to object + T const *ptr_; + + public: + + CUTLASS_HOST_DEVICE + const_reverse_iterator(): ptr_(nullptr) { } + + CUTLASS_HOST_DEVICE + const_reverse_iterator(T const *_ptr): ptr_(_ptr) { } + + CUTLASS_HOST_DEVICE + const_reverse_iterator &operator++() { + --ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator &operator--() { + ++ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator operator++(int) { + const_reverse_iterator ret(*this); + --ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator operator--(int) { + const_reverse_iterator ret(*this); + ++ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + T const &operator*() const { + return *(ptr_ - 1); + } + + CUTLASS_HOST_DEVICE + bool operator==(const_iterator const &other) const { + return ptr_ == other.ptr_; + } + + CUTLASS_HOST_DEVICE + bool operator!=(const_iterator const &other) const { + return ptr_ != other.ptr_; + } + }; + +private: + + /// Internal storage + Storage storage[kElements]; + +public: + + #if 0 + CUTLASS_HOST_DEVICE + Array() { } + + CUTLASS_HOST_DEVICE + Array(Array const &x) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElements; ++i) { + storage[i] = x.storage[i]; + } + } + #endif + + /// Efficient clear method + CUTLASS_HOST_DEVICE + void clear() { + fill(T(0)); + } + + CUTLASS_HOST_DEVICE + reference at(size_type pos) { + return reinterpret_cast(storage[pos]); + } + + CUTLASS_HOST_DEVICE + const_reference at(size_type pos) const { + return reinterpret_cast(storage[pos]); + } + + CUTLASS_HOST_DEVICE + reference operator[](size_type pos) { + return reinterpret_cast(storage[pos]); + } + + CUTLASS_HOST_DEVICE + const_reference operator[](size_type pos) const { + return reinterpret_cast(storage[pos]); + } + + CUTLASS_HOST_DEVICE + reference front() { + return reinterpret_cast(storage[0]); + } + + CUTLASS_HOST_DEVICE + const_reference front() const { + return reinterpret_cast(storage[0]); + } + + CUTLASS_HOST_DEVICE + reference back() { + return reinterpret_cast(storage[kStorageElements - 1]); + } + + CUTLASS_HOST_DEVICE + const_reference back() const { + return reinterpret_cast(storage[kStorageElements - 1]); + } + + CUTLASS_HOST_DEVICE + pointer data() { + return reinterpret_cast(storage); + } + + CUTLASS_HOST_DEVICE + const_pointer data() const { + return reinterpret_cast(storage); + } + + CUTLASS_HOST_DEVICE + pointer raw_data() { + return reinterpret_cast(storage); + } + + CUTLASS_HOST_DEVICE + const_pointer raw_data() const { + return reinterpret_cast(storage); + } + + + CUTLASS_HOST_DEVICE + constexpr bool empty() const { + return !kElements; + } + + CUTLASS_HOST_DEVICE + constexpr size_type size() const { + return kElements; + } + + CUTLASS_HOST_DEVICE + constexpr size_type max_size() const { + return kElements; + } + + CUTLASS_HOST_DEVICE + void fill(T const &value) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(kElements); ++i) { + storage[i] = static_cast(value); + } + } + + CUTLASS_HOST_DEVICE + iterator begin() { + return iterator(storage); + } + + CUTLASS_HOST_DEVICE + const_iterator begin() const { + return cbegin(); + } + + CUTLASS_HOST_DEVICE + const_iterator cbegin() const { + return const_iterator(storage); + } + + CUTLASS_HOST_DEVICE + iterator end() { + return iterator(reinterpret_cast(storage + kStorageElements)); 
+ } + + CUTLASS_HOST_DEVICE + const_iterator end() const { + return cend(); + } + + CUTLASS_HOST_DEVICE + const_iterator cend() const { + return const_iterator(reinterpret_cast(storage + kStorageElements)); + } + + CUTLASS_HOST_DEVICE + reverse_iterator rbegin() { + return reverse_iterator(reinterpret_cast(storage + kStorageElements)); + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator rbegin() const { + return crbegin(); + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator crbegin() const { + return const_reverse_iterator(reinterpret_cast(storage + kStorageElements)); + } + + CUTLASS_HOST_DEVICE + reverse_iterator rend() { + return reverse_iterator(reinterpret_cast(storage)); + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator rend() const { + return crend(); + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator crend() const { + return const_reverse_iterator(reinterpret_cast(storage)); + } + + // + // Comparison operators + // + +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Factories +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_HOST_DEVICE +Array make_Array(Element x) { + Array m; + m[0] = x; + return m; +} + +template +CUTLASS_HOST_DEVICE +Array make_Array(Element x, Element y) { + Array m; + m[0] = x; + m[1] = y; + return m; +} + +template +CUTLASS_HOST_DEVICE +Array make_Array(Element x, Element y, Element z) { + Array m; + m[0] = x; + m[1] = y; + m[2] = z; + return m; +} + +template +CUTLASS_HOST_DEVICE +Array make_Array(Element x, Element y, Element z, Element w) { + Array m; + m[0] = x; + m[1] = y; + m[2] = z; + m[3] = w; + return m; +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct absolute_value_op< Array > { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + absolute_value_op scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + +template +struct plus> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + plus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + plus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + plus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; +template +struct minus> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + minus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + minus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + 
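+  // Mirror overload: broadcasts a scalar left-hand operand against every element of the
+  // right-hand Array (the loop applies scalar_op(scalar, rhs[i]) per element).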
+ CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + minus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct multiplies> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + multiplies scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + multiplies scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + multiplies scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct maximum_absolute_value_reduction, PropogateNaN> { + + CUTLASS_HOST_DEVICE + T operator() (T const& scalar, Array const& rhs) const { + + T result = scalar; + maximum_absolute_value_reduction scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result = scalar_op(result, rhs[i]); + } + + return result; + } +}; + +template +struct scale> { + T const scaling_factor_; + + CUTLASS_HOST_DEVICE + scale(T scaling_factor) : scaling_factor_(scaling_factor) { + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & rhs) const { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = rhs[i] * scaling_factor_; + } + + return result; + } +}; + +template +struct divides> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + divides scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + divides scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + divides scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct reciprocal_approximate> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + reciprocal_approximate scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + +template +struct maximum, false> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; 
i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct maximum, true> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct minimum, false> { + + CUTLASS_HOST_DEVICE + static T scalar_op(T const &lhs, T const &rhs) { + return (rhs < lhs ? rhs : lhs); + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct minimum, true> { + + CUTLASS_HOST_DEVICE + static T scalar_op(T const &lhs, T const &rhs) { + return (rhs < lhs ? 
rhs : lhs); + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct negate> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + negate scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], b[i], c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], scalar, c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, b[i], c[i]); + } + + return result; + } +}; + +/// Fused square-and-plus +template +struct square_and_plus> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + multiply_add, Array, Array> ma_op; + return ma_op(rhs, rhs, lhs); + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &rhs) const { + plus> plus_op; + multiplies multiplies_op; + return plus_op(multiplies_op(rhs, rhs), lhs); + } +}; + +/// Fused multiply-add-relu0 +template +struct multiply_add_relu0, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0)); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0)); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0)); + } + + return result; + } +}; + + +template +struct conjugate > { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a) const { + + conjugate conj_op; + + Array ca; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + ca[i] = conj_op(a[i]); + } + return 
ca; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations targeting SIMD instructions in device code. +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct plus> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] + rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs + rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] + rhs; + } + #endif + + return result; + } +}; + +template +struct minus> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const 
*>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] - rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs - rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] - rhs; + } + #endif + + return result; + } +}; + +template +struct multiplies> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] * rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr 
= reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmul( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs * rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hmul( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] * rhs; + } + #endif + + return result; + } +}; + +template +struct divides> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hdiv( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] / rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hdiv( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs / rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = 
reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hdiv( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] / rhs; + } + #endif + + return result; + } +}; + +template +struct negate> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hneg2(source_ptr[i]); + } + + if constexpr (N % 2) { + half_t x = -lhs[N - 1]; + __half lhs_val = reinterpret_cast<__half const &>(x); + result[N - 1] = reinterpret_cast(lhs_val); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = -lhs[i]; + } + #endif + + return result; + } +}; + +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + half_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + __half d_residual = __hfma( + reinterpret_cast<__half const &>(a), + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a, b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + half_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const 
*a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(b), + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b, c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + half_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + reinterpret_cast<__half const &>(c)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c); + } + #endif + + return result; + } +}; + +/// Fused multiply-add-relu0 +template +struct multiply_add_relu0, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b[i], c[i]), (half_t)0); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + half_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = 
reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + __half d_residual = __hfma_relu( + reinterpret_cast<__half const &>(a), + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a, b[i], c[i]), half_t(0)); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + half_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(b), + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b, c[i]), half_t(0)); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + half_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + reinterpret_cast<__half const &>(c)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b[i], c), half_t(0)); + } + #endif + + return result; + } +}; + +template +struct minimum, false> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmin2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const 
*b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmin( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (rhs[i] < lhs[i] ? rhs[i] : lhs[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmin2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmin( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (rhs[i] < lhs ? rhs[i] : lhs); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmin2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hmin( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (rhs < lhs[i] ? rhs : lhs[i]); + } + #endif + + return result; + } +}; + +template +struct maximum, false> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmax( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (lhs[i] < rhs[i] ? 
rhs[i] : lhs[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmax2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmax( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (lhs < rhs[i] ? rhs[i] : lhs); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hmax( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (lhs[i] < rhs ? rhs : lhs[i]); + } + #endif + + return result; + } +}; + +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *b_ptr = reinterpret_cast(&b); + unsigned const *c_ptr = reinterpret_cast(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_ptr[i]) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + bfloat16_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *b_ptr = reinterpret_cast(&b); + unsigned const *c_ptr = reinterpret_cast(&c); + + unsigned a_packed = static_cast(a.raw()); + a_packed = (a_packed | (a_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, 
%2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_packed), "r"(b_ptr[i]), "r"(c_ptr[i]) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[0]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a, b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + bfloat16_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *c_ptr = reinterpret_cast(&c); + + unsigned b_packed = static_cast(b.raw()); + b_packed = (b_packed | (b_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_packed), "r"(c_ptr[i]) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b, c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + bfloat16_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *b_ptr = reinterpret_cast(&b); + + unsigned c_packed = static_cast(c.raw()); + c_packed = (c_packed | (c_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_packed) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[0]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c); + } + #endif + + return result; + } +}; + + +/// bit_and +template +struct bit_and> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + Storage const *b_data = b.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (a_data[i] & b_data[i]); + } + + return result; + } +}; + + +/// bit_or 
+template +struct bit_or> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + Storage const *b_data = b.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (a_data[i] | b_data[i]); + } + + return result; + } +}; + + +/// bit_not +template +struct bit_not> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (~a_data[i]); + } + + return result; + } +}; + + +/// bit_xor +template +struct bit_xor> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + Storage const *b_data = b.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (a_data[i] ^ b_data[i]); + } + + return result; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Operator overloads +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_HOST_DEVICE +Array operator+(Array const &lhs, Array const &rhs) { + plus> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator-(Array const &lhs, Array const &rhs) { + minus> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator-(Array const &lhs) { + negate> op; + return op(lhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator*(Array const &lhs, Array const &rhs) { + multiplies> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator*(T lhs, Array const &rhs) { + multiplies> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator*(Array const &lhs, T rhs) { + multiplies> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator/(Array const &lhs, Array const &rhs) { + divides> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array fma(Array const &a, Array const &b, Array const &c) { + multiply_add> op; + return op(a, b, c); +} + +template +CUTLASS_HOST_DEVICE +Array fma(T a, Array const &b, Array const &c) { + multiply_add> op; + return op(a, b, c); +} + +template +CUTLASS_HOST_DEVICE +Array fma(Array const &a, T b, Array const &c) { + multiply_add> op; + return op(a, b, c); +} + +template +CUTLASS_HOST_DEVICE +Array fma(Array const &a, Array const &b, T c) { + multiply_add> op; + return op(a, b, c); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/array_subbyte.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 
AlignedArray +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Aligned array type +template < + /// Element type + typename T, + /// Number of elements in the array + int N, + /// Alignment requirement in bytes + int Alignment = ( sizeof_bits<T>::value * N + 7 ) / 8 +> +class alignas(Alignment) AlignedArray: public Array<T, N> { +public: + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/array_planar_complex.h b/server/punica_kernels/include/cutlass/cutlass/array_planar_complex.h new file mode 100644 index 00000000..85268cac --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/array_planar_complex.h @@ -0,0 +1,103 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Array holding planar complex elements +template +struct ArrayPlanarComplex { + + /// Underlying real element + using Element = Element_; + + /// Number of logical elements + static size_t const kElements = N; + + /// Underlying Fragment of real-valued elemenets + using ArrayReal = Array; + +public: + + /// Fragment of real-valued elements representing the real part + ArrayReal real; + + /// Fragment of real-valued elements representing the imaginary part + ArrayReal imag; + +public: + + /// Ctor + CUTLASS_HOST_DEVICE + ArrayPlanarComplex() { } + + /// Ctor + CUTLASS_HOST_DEVICE + ArrayPlanarComplex( + ArrayReal const &real_, + ArrayReal const &imag_ + ): + real(real_), imag(imag_) { } + + /// Sets the array to zero efficiently + CUTLASS_HOST_DEVICE + void clear() { + real.clear(); + imag.clear(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to deduce template arguments +template +CUTLASS_HOST_DEVICE +ArrayPlanarComplex +make_ArrayPlanarComplex(Array const &real, Array const &imag) { + return ArrayPlanarComplex(real, imag); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/array_subbyte.h b/server/punica_kernels/include/cutlass/cutlass/array_subbyte.h new file mode 100644 index 00000000..04232ea4 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/array_subbyte.h @@ -0,0 +1,573 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types + and is safe to use in a union. +*/ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Statically sized array for any data type +template < + typename T, + int N +> +class Array { +public: + + static int const kSizeBits = sizeof_bits::value * N; + + /// Storage type + using Storage = typename platform::conditional< + ((kSizeBits % 32) != 0), + typename platform::conditional< + ((kSizeBits % 16) != 0), + uint8_t, + uint16_t + >::type, + uint32_t + >::type; + + /// Element type + using Element = T; + + /// Number of logical elements per stored object + static int const kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; + + /// Number of storage elements + static size_t const kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; + + /// Number of logical elements + static size_t const kElements = N; + + /// Bitmask for covering one item + static Storage const kMask = ((Storage(1) << sizeof_bits::value) - 1); + + // + // C++ standard members with pointer types removed + // + + typedef T value_type; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + typedef value_type *pointer; + typedef value_type const *const_pointer; + + // + // References + // + + /// Reference object inserts or extracts sub-byte items + class reference { + /// Pointer to storage element + Storage *ptr_; + + /// Index into elements packed into Storage object + int idx_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + reference(): ptr_(nullptr), idx_(0) { } + + /// Ctor + CUTLASS_HOST_DEVICE + reference(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + + /// Assignment + CUTLASS_HOST_DEVICE + reference &operator=(T x) { + Storage item = (reinterpret_cast(x) & kMask); + + Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits::value))); + *ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits::value))); + + return *this; + } + + CUTLASS_HOST_DEVICE + T get() const { + Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & kMask); + return reinterpret_cast(item); + } + + /// Extract + CUTLASS_HOST_DEVICE + operator T() const { + return get(); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const 
{ + return int(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + }; + + /// Reference object extracts sub-byte items + class const_reference { + + /// Pointer to storage element + Storage const *ptr_; + + /// Index into elements packed into Storage object + int idx_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + const_reference(): ptr_(nullptr), idx_(0) { } + + /// Ctor + CUTLASS_HOST_DEVICE + const_reference(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + + CUTLASS_HOST_DEVICE + const T get() const { + Storage item = (*ptr_ >> (idx_ * sizeof_bits::value)) & kMask; + return reinterpret_cast(item); + } + + /// Extract + CUTLASS_HOST_DEVICE + operator T() const { + Storage item = Storage(Storage(*ptr_ >> Storage(idx_ * sizeof_bits::value)) & kMask); + return reinterpret_cast(item); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + }; + + // + // Iterators + // + + /// Bidirectional iterator over elements + class iterator { + + /// Pointer to storage element + Storage *ptr_; + + /// Index into elements packed into Storage object + int idx_; + + public: + + CUTLASS_HOST_DEVICE + iterator(): ptr_(nullptr), idx_(0) { } + + CUTLASS_HOST_DEVICE + iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + + CUTLASS_HOST_DEVICE + iterator &operator++() { + ++idx_; + if (idx_ == kElementsPerStoredItem) { + ++ptr_; + idx_ = 0; + } + return *this; + } + + CUTLASS_HOST_DEVICE + iterator &operator--() { + if (!idx_) { + --ptr_; + idx_ = kElementsPerStoredItem - 1; + } + else { + --idx_; + } + return *this; + } + + CUTLASS_HOST_DEVICE + iterator operator++(int) { + iterator ret(*this); + ++idx_; + if (idx_ == kElementsPerStoredItem) { + ++ptr_; + idx_ = 0; + } + return ret; + } + + CUTLASS_HOST_DEVICE + iterator operator--(int) { + iterator ret(*this); + if (!idx_) { + --ptr_; + idx_ = kElementsPerStoredItem - 1; + } + else { + --idx_; + } + return ret; + } + + CUTLASS_HOST_DEVICE + reference operator*() const { + return reference(ptr_, idx_); + } + + CUTLASS_HOST_DEVICE + bool operator==(iterator const &other) const { + return ptr_ == other.ptr_ && idx_ == other.idx_; + } + + CUTLASS_HOST_DEVICE + bool operator!=(iterator const &other) const { + return !(*this == other); + } + }; + + /// Bidirectional constant iterator over elements + class const_iterator { + + /// Pointer to storage element + Storage const *ptr_; + + /// Index into elements packed into Storage object + int idx_; + + public: + + CUTLASS_HOST_DEVICE + const_iterator(): ptr_(nullptr), idx_(0) { } + + CUTLASS_HOST_DEVICE + const_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + + CUTLASS_HOST_DEVICE + iterator &operator++() { + ++idx_; + if (idx_ == kElementsPerStoredItem) { + ++ptr_; + idx_ = 0; + } + return *this; + } + + CUTLASS_HOST_DEVICE + iterator &operator--() { + if (!idx_) { + --ptr_; + idx_ = kElementsPerStoredItem - 1; + } + else { + --idx_; + } + return *this; + } + + CUTLASS_HOST_DEVICE + iterator operator++(int) { + iterator ret(*this); + ++idx_; + if (idx_ == kElementsPerStoredItem) { + ++ptr_; + idx_ = 0; + } + return ret; + } + + CUTLASS_HOST_DEVICE + iterator operator--(int) { + iterator ret(*this); + if (!idx_) { + --ptr_; + idx_ = kElementsPerStoredItem - 1; + } + else { + --idx_; + } + return ret; + } + + 
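    // [Editor's note - illustrative sketch, not part of the vendored CUTLASS header.]
    // Worked example of the packing these iterators walk over, assuming the usual
    // sub-byte type cutlass::int4b_t (sizeof_bits == 4):
    //   Array<cutlass::int4b_t, 8>  ->  kSizeBits = 4 * 8 = 32, Storage = uint32_t,
    //                                   kElementsPerStoredItem = 8, kStorageElements = 1
    // so ++/-- advance idx_ through the 8 nibbles packed in one Storage word and only
    // then move ptr_ to the next word, exactly as in the mutable iterator above.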
CUTLASS_HOST_DEVICE + const_reference operator*() const { + return const_reference(ptr_, idx_); + } + + CUTLASS_HOST_DEVICE + bool operator==(iterator const &other) const { + return ptr_ == other.ptr_ && idx_ == other.idx_; + } + + CUTLASS_HOST_DEVICE + bool operator!=(iterator const &other) const { + return !(*this == other); + } + }; + + /// Bidirectional iterator over elements + class reverse_iterator { + + /// Pointer to storage element + Storage *ptr_; + + /// Index into elements packed into Storage object + int idx_; + + public: + + CUTLASS_HOST_DEVICE + reverse_iterator(): ptr_(nullptr), idx_(0) { } + + CUTLASS_HOST_DEVICE + reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + }; + + /// Bidirectional constant iterator over elements + class const_reverse_iterator { + + /// Pointer to storage element + Storage const *ptr_; + + /// Index into elements packed into Storage object + int idx_; + + public: + + CUTLASS_HOST_DEVICE + const_reverse_iterator(): ptr_(nullptr), idx_(0) { } + + CUTLASS_HOST_DEVICE + const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + }; + +private: + + /// Internal storage + Storage storage[kStorageElements] = {Storage{0}}; + +public: + + #if 0 + CUTLASS_HOST_DEVICE + Array() { } + + CUTLASS_HOST_DEVICE + Array(Array const &x) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(kStorageElements); ++i) { + storage[i] = x.storage[i]; + } + } + #endif + + /// Efficient clear method + CUTLASS_HOST_DEVICE + void clear() { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(kStorageElements); ++i) { + storage[i] = Storage(0); + } + } + + CUTLASS_HOST_DEVICE + reference at(size_type pos) { + return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); + } + + CUTLASS_HOST_DEVICE + const_reference at(size_type pos) const { + return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); + } + + CUTLASS_HOST_DEVICE + reference operator[](size_type pos) { + return at(pos); + } + + CUTLASS_HOST_DEVICE + const_reference operator[](size_type pos) const { + return at(pos); + } + + CUTLASS_HOST_DEVICE + reference front() { + return at(0); + } + + CUTLASS_HOST_DEVICE + const_reference front() const { + return at(0); + } + + CUTLASS_HOST_DEVICE + reference back() { + return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); + } + + CUTLASS_HOST_DEVICE + const_reference back() const { + return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); + } + + CUTLASS_HOST_DEVICE + pointer data() { + return reinterpret_cast(storage); + } + + CUTLASS_HOST_DEVICE + const_pointer data() const { + return reinterpret_cast(storage); + } + + CUTLASS_HOST_DEVICE + Storage * raw_data() { + return storage; + } + + CUTLASS_HOST_DEVICE + Storage const * raw_data() const { + return storage; + } + + + CUTLASS_HOST_DEVICE + constexpr bool empty() const { + return !kElements; + } + + CUTLASS_HOST_DEVICE + constexpr size_type size() const { + return kElements; + } + + CUTLASS_HOST_DEVICE + constexpr size_type max_size() const { + return kElements; + } + + CUTLASS_HOST_DEVICE + void fill(T const &value) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerStoredItem; ++i) { + reference ref(storage, i); + ref = value; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kStorageElements; ++i) { + storage[i] = storage[0]; + } + } + + CUTLASS_HOST_DEVICE + iterator begin() { + return iterator(storage); + } + + CUTLASS_HOST_DEVICE + const_iterator 
cbegin() const { + return const_iterator(storage); + } + + CUTLASS_HOST_DEVICE + iterator end() { + return iterator(storage + kStorageElements); + } + + CUTLASS_HOST_DEVICE + const_iterator cend() const { + return const_iterator(storage + kStorageElements); + } + + CUTLASS_HOST_DEVICE + reverse_iterator rbegin() { + return reverse_iterator(storage + kStorageElements); + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator crbegin() const { + return const_reverse_iterator(storage + kStorageElements); + } + + CUTLASS_HOST_DEVICE + reverse_iterator rend() { + return reverse_iterator(storage); + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator crend() const { + return const_reverse_iterator(storage); + } + + // + // Comparison operators + // + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/barrier.h b/server/punica_kernels/include/cutlass/cutlass/barrier.h new file mode 100644 index 00000000..94f300ad --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/barrier.h @@ -0,0 +1,377 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implementation of a CTA-wide barrier for inter-CTA synchronization. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +namespace detail { + +// +// Utilities for abstracting synchronization methods for barriers +// + +struct SyncthreadsSync { + CUTLASS_DEVICE + static void sync() { + __syncthreads(); + } +}; + +struct SyncwarpSync { + CUTLASS_DEVICE + static void sync() { + __syncwarp(); + } +}; + +template < + int ThreadCount, + int BarrierId +> +struct NamedBarrierSync { + CUTLASS_DEVICE + static void sync() { + cutlass::arch::NamedBarrier::sync(ThreadCount, static_cast(BarrierId)); + } +}; + +} // namepspace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Group or CTA-wide semaphore for inter-CTA synchronization. +template +struct GenericBarrier { + +public: + + /// Flag type + using T = int; + + /// Initial flag value + static const T INIT = 0; + + +protected: + + /// Load flag, as a strong acquire operation (int specialization) + CUTLASS_DEVICE + static int ld_acquire(int *ptr) + { + int state = 0; + +#if (__CUDA_ARCH__ >= 700) + /// SM70 and newer use memory consistency qualifiers + + // Acquire pattern using acquire modifier + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); + +#else + asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); +#endif // (__CUDA_ARCH__ >= 700) + + return state; + } + + + /// Reduce into flag, with release pattern (int specialization) + CUTLASS_DEVICE + static void red_release(int *ptr, int val) + { +#if (__CUDA_ARCH__ >= 700) + /// SM70 and newer use memory consistency qualifiers + + // Release pattern using acq_rel fence + relaxed modifier. 
(The fence also releases data + // that was weakly-written by other threads prior to the last syncthreads) + asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val)); + +#else + __threadfence(); + atomicAdd(ptr, val); +#endif // (__CUDA_ARCH__ >= 700) + } + + +public: + + /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter + CUTLASS_DEVICE + static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count) + { + T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + if (thread_idx == 0) + { + // Spin-loop + #pragma unroll 1 + while(ld_acquire(flag_ptr) < count) {} + } + + Sync::sync(); + } + + /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter + CUTLASS_DEVICE + static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) + { + T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + if (thread_idx == 0) + { + // Spin-loop + #pragma unroll 1 + while(ld_acquire(flag_ptr) != val) {} + } + Sync::sync(); + } + + /// Uses thread[0] to wait for the specified count of signals on the given flag counter + CUTLASS_DEVICE + static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + if (thread_idx == 0) + { + // Spin-loop + #pragma unroll 1 + while(atomicCAS(flag_ptr, val, 0) != val) {} + } + + Sync::sync(); + } + + /// Increment the arrival count for a flag + CUTLASS_DEVICE + static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx, int val = 1) + { + T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + Sync::sync(); + + if (thread_idx == 0) + { + red_release(flag_ptr, val); + } + } + + + /// Increment the arrival counts for a range of flags + CUTLASS_DEVICE + static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) + { + int flag_idx = first_flag_idx + thread_idx; + T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + // Barrier to make sure all other threads in group have written their data + Sync::sync(); + + // Select threads increment their flags + if (thread_idx < count) { + red_release(flag_ptr, val); + } + } +}; + +using Barrier = GenericBarrier; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/** Structure for managing multiple NamedBarriers to be used by different warp groups, allowing + * runtime index values to be used to call into named barriers with compile-time-constant IDs. + * + * @param ThreadCount_ Number of threads that will wait on a NamedBarrier with a given ID + * @param Offset Value added to the ID passed in by the user to determine the NamedBarrier ID to call into + * @param MaxNumNamedBarriers The maximum number of unique barrier IDs that will be requested on this type +**/ +template < + uint32_t ThreadCount_, + uint32_t Offset = 0, + uint32_t MaxNumNamedBarriers = 16 +> +struct NamedBarrierManager { + + static_assert(MaxNumNamedBarriers <= arch::NamedBarrier::HardwareMaxNumNamedBarriers); + static_assert(MaxNumNamedBarriers + Offset <= arch::NamedBarrier::HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15"); + + // Number of threads participating in the barrier + static constexpr uint32_t ThreadCount = ThreadCount_; + + template + using BarrierSync = cutlass::GenericBarrier>; + + // Underlying type used by all barriers for synchronization. 
Does not depend on + // template parameter BarrierId, so passing in 0 suffices. + using T = typename BarrierSync<0>::T; + + using IntegerSequence = cute::make_integer_sequence; + + CUTLASS_DEVICE + static + void wait_lt(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count) { + wait_lt_helper(idx, lock_ptr, thread_idx, flag_idx, count, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + wait_eq(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + wait_eq_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + wait_eq_reset(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + wait_eq_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + arrive_inc(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) { + arrive_inc_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) { + arrive_range_inc_helper(idx, lock_ptr, thread_idx, first_flag_idx, count, val, IntegerSequence{}); + } + +private: + CUTLASS_DEVICE + static void + check_barrier_in_range([[maybe_unused]] uint32_t idx) { + assert((idx >= MaxNumNamedBarriers) && "Index exceeds barrier count"); + } + + template + CUTLASS_DEVICE + static void + wait_lt_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count, cute::integer_sequence) { + check_barrier_in_range(idx); + ((Idx == idx && (BarrierSync::wait_lt(lock_ptr, thread_idx, flag_idx, count), true)) || ...); + } + + template + CUTLASS_DEVICE + static void + wait_eq_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val, cute::integer_sequence) { + check_barrier_in_range(idx); + if constexpr (Reset) { + ((Idx == idx && (BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val), true)) || ...); + } + else { + ((Idx == idx && (BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val), true)) || ...); + } + } + + template + CUTLASS_DEVICE + static void + arrive_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val, cute::integer_sequence) { + check_barrier_in_range(idx); + ((Idx == idx && (BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val), true)) || ...); + } + + template + CUTLASS_DEVICE + static void + arrive_range_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count, int val, cute::integer_sequence) { + check_barrier_in_range(idx); + ((Idx == idx && (BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val), true)) || ...); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/** Structure for synchronizing via contiguous barriers (e.g., __syncwarp, __syncthreads) + * via an API that mirrors that of NamedBarrierManager + * + * @param Synchronizer Synchronization helper exposing a `sync()` method to perform synchronization +**/ +template < + class Synchronizer, + uint32_t ThreadCount_ +> +struct SyncManager { + + // Number of threads participating in the barrier + static constexpr uint32_t ThreadCount = ThreadCount_; + + using BarrierSync = cutlass::GenericBarrier; + + // Underlying type used by all barriers for synchronization. 
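  // [Editor's note - illustrative usage sketch, an assumption rather than code from this
  // header; counter_ptr and num_producers are placeholder names.] The Barrier alias
  // defined above (GenericBarrier, presumably synchronizing with __syncthreads) is
  // typically driven through a zero-initialized counter in a global-memory workspace,
  // e.g. for a producer/consumer hand-off between CTAs:
  //
  //   // producer CTA, after its results are visible in the workspace:
  //   cutlass::Barrier::arrive_inc(counter_ptr, threadIdx.x, /*flag_idx=*/0);
  //
  //   // consumer CTA, before reading the workspace:
  //   cutlass::Barrier::wait_eq(counter_ptr, threadIdx.x, /*flag_idx=*/0, /*val=*/num_producers);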
+ using T = typename BarrierSync::T; + + CUTLASS_DEVICE + static + void wait_lt(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int count) { + BarrierSync::wait_lt(lock_ptr, thread_idx, flag_idx, count); + } + + CUTLASS_DEVICE + static void + wait_eq(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val); + } + + CUTLASS_DEVICE + static void + wait_eq_reset(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val); + } + + CUTLASS_DEVICE + static void + arrive_inc(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) { + BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val); + } + + CUTLASS_DEVICE + static void + arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) { + BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/bfloat16.h b/server/punica_kernels/include/cutlass/cutlass/bfloat16.h new file mode 100644 index 00000000..05f46197 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/bfloat16.h @@ -0,0 +1,527 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines a proxy class for storing non-standard 16-bit floating point values with + 8 bits of exponent and 7 bit of mantissa. 
+*/ + +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include "cutlass/floating_point_nvrtc.h" +#else +#include +#include +#include +#include +#endif + +#include +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Floating-point type with 8 bits of exponent and 7 bits of mantissa. +struct alignas(2) bfloat16_t { + + // + // Data members + // + + /// Storage type + uint16_t storage; + + // + // Methods + // + + /// Constructs from an unsigned short + CUTLASS_HOST_DEVICE + static bfloat16_t bitcast(uint16_t x) { + bfloat16_t h; + h.storage = x; + return h; + } + +private: + struct from_32_bit_integer_t {}; + static constexpr from_32_bit_integer_t from_32_bit_integer{}; + + template + CUTLASS_HOST_DEVICE + explicit bfloat16_t(from_32_bit_integer_t, T x) { + static_assert(cutlass::platform::is_integral::value && sizeof(T) == 4, "Requires 32-bit integer"); + + float flt = static_cast(x); + uint32_t bits; + + #if defined(__CUDA_ARCH__) + bits = reinterpret_cast(flt); + #else + std::memcpy(&bits, &flt, sizeof(bits)); + #endif + + storage = uint16_t(bits >> 16); + } + +public: + /// Default constructor + bfloat16_t() = default; + + /// Floating-point conversion - round toward nearest + CUTLASS_HOST_DEVICE + explicit bfloat16_t(float x) { + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) + + asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x)); + + #else + uint32_t bits; + + #if defined(__CUDA_ARCH__) + bits = reinterpret_cast(x); + #else + std::memcpy(&bits, &x, sizeof(bits)); + #endif + + if ((bits & 0x7f800000) != 0x7f800000) { + + bool mantissa_bit = ((bits & (1 << 16)) != 0); + bool round_bit = ((bits & (1 << 15)) != 0); + bool sticky_bit = ((bits & ((1 << 15) - 1)) != 0); + + if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) { + bits += uint32_t(1 << 16); + } + } + else if (bits & ~0xff800000) { + bits = 0x7fffffff; + } + + storage = uint16_t((bits >> 16) & 0xffff); + #endif + } + + /// Floating-point conversion - round toward nearest + CUTLASS_HOST_DEVICE + explicit bfloat16_t(double x): bfloat16_t(float(x)) { + + } + + /// Integer conversion - round toward nearest + CUTLASS_HOST_DEVICE + explicit bfloat16_t(int x) : bfloat16_t(from_32_bit_integer, x) {} + + CUTLASS_HOST_DEVICE + explicit bfloat16_t(uint32_t x) : bfloat16_t(from_32_bit_integer, x) {} + + /// Converts to float + CUTLASS_HOST_DEVICE + operator float() const { + unsigned bits = (unsigned(storage) << 16); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(bits); + #else + float flt; + std::memcpy(&flt, &bits, sizeof(flt)); + return flt; + #endif + } + + /// Converts to float + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(float(*this)); + } + + /// Converts to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(float(*this)); + } + + /// Casts to bool + CUTLASS_HOST_DEVICE + explicit operator bool() const { + return (float(*this) != 0.0f); + } + + /// Obtains raw bits + CUTLASS_HOST_DEVICE + uint16_t raw() const 
{ + return storage; + } + /// Returns the sign bit + CUTLASS_HOST_DEVICE + bool signbit() const { + return ((raw() & 0x8000) != 0); + } + + /// Returns the biased exponent + CUTLASS_HOST_DEVICE + int exponent_biased() const { + return int((raw() >> 7) & 0x0ff); + } + + /// Returns the unbiased exponent + CUTLASS_HOST_DEVICE + int exponent() const { + return exponent_biased() - 127; + } + + /// Returns the mantissa + CUTLASS_HOST_DEVICE + int mantissa() const { + return int(raw() & 0x7f); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool signbit(cutlass::bfloat16_t const& h) { + return h.signbit(); +} + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t abs(cutlass::bfloat16_t const& h) { + return cutlass::bfloat16_t::bitcast(h.raw() & 0x7fff); +} + +CUTLASS_HOST_DEVICE +bool isnan(cutlass::bfloat16_t const& h) { + return (h.exponent_biased() == 0x0ff) && h.mantissa(); +} + +CUTLASS_HOST_DEVICE +bool isfinite(cutlass::bfloat16_t const& h) { + return (h.exponent_biased() != 0x0ff); +} + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t nan_bf16(const char*) { + // NVIDIA canonical NaN + return cutlass::bfloat16_t::bitcast(0x7fff); +} + +CUTLASS_HOST_DEVICE +bool isinf(cutlass::bfloat16_t const& h) { + return (h.exponent_biased() == 0x0ff) && !h.mantissa(); +} + +CUTLASS_HOST_DEVICE +bool isnormal(cutlass::bfloat16_t const& h) { + return h.exponent_biased() && h.exponent_biased() != 0x0ff; +} + +CUTLASS_HOST_DEVICE +int fpclassify(cutlass::bfloat16_t const& h) { + int exp = h.exponent_biased(); + int mantissa = h.mantissa(); + if (exp == 0x0ff) { + if (mantissa) { + return FP_NAN; + } + else { + return FP_INFINITE; + } + } + else if (!exp) { + if (mantissa) { + return FP_SUBNORMAL; + } + else { + return FP_ZERO; + } + } + return FP_NORMAL; +} + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) { +#if defined(__CUDACC_RTC__) + return cutlass::bfloat16_t(sqrtf(float(h))); +#else + return cutlass::bfloat16_t(std::sqrt(float(h))); +#endif +} + +CUTLASS_HOST_DEVICE +bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) { + + uint16_t a_bits; + uint16_t b_bits; + + #if defined(__CUDA_ARCH__) + a_bits = reinterpret_cast(a); + b_bits = reinterpret_cast(b); + #else + std::memcpy(&a_bits, &a, sizeof(a_bits)); + std::memcpy(&b_bits, &b, sizeof(b_bits)); + #endif + + uint16_t a_mag = (a_bits & 0x7fff); + uint16_t b_sign = (b_bits & 0x8000); + uint16_t result = (a_mag | b_sign); + + return bfloat16_t::bitcast(result); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Standard Library operations and definitions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace std { + +#if !defined(__CUDACC_RTC__) +/// Numeric limits +template <> +struct numeric_limits { + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_infinity = true; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; + static std::float_denorm_style const has_denorm = std::denorm_present; + static bool const has_denorm_loss = true; + static std::float_round_style const round_style = std::round_to_nearest; + static bool 
const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = 7; + + /// Least positive value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); } + + /// Minimum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } +}; +#endif + +} // namespace std + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Arithmetic operators +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return float(lhs) == float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return float(lhs) != float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return float(lhs) < float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return float(lhs) <= float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return float(lhs) > float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return float(lhs) >= float(rhs); +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return bfloat16_t(float(lhs) + float(rhs)); +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator-(bfloat16_t const& lhs) { + return bfloat16_t(-float(lhs)); +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return bfloat16_t(float(lhs) - float(rhs)); +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return bfloat16_t(float(lhs) * float(rhs)); +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) { + return bfloat16_t(float(lhs) / float(rhs)); +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) { + lhs = bfloat16_t(float(lhs) + float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) { + lhs = bfloat16_t(float(lhs) - float(rhs)); + return lhs; 
+} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) { + lhs = bfloat16_t(float(lhs) * float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) { + lhs = bfloat16_t(float(lhs) / float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator++(bfloat16_t & lhs) { + float tmp(lhs); + ++tmp; + lhs = bfloat16_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator--(bfloat16_t & lhs) { + float tmp(lhs); + --tmp; + lhs = bfloat16_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator++(bfloat16_t & lhs, int) { + bfloat16_t ret(lhs); + float tmp(lhs); + tmp++; + lhs = bfloat16_t(tmp); + return ret; +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator--(bfloat16_t & lhs, int) { + bfloat16_t ret(lhs); + float tmp(lhs); + tmp--; + lhs = bfloat16_t(tmp); + return ret; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// User-defined literals +// + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t operator "" _bf16(long double x) { + return cutlass::bfloat16_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) { + return cutlass::bfloat16_t(int(x)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/blas3.h b/server/punica_kernels/include/cutlass/cutlass/blas3.h new file mode 100644 index 00000000..0697a87c --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/blas3.h @@ -0,0 +1,142 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief Basic include for CUTLASS BLAS3/HPC code. + + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/blas3_types.h" +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines FillMode inversions +template +struct InvertFillMode; + +/// Invert FillMode lower to upper +template <> +struct InvertFillMode { + static FillMode const mode = FillMode::kUpper; +}; + +/// Invert FillMode upper to lower +template <> +struct InvertFillMode { + static FillMode const mode = FillMode::kLower; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines SideMode inversions +template +struct InvertSideMode; + +/// Invert SideMode left to right +template <> +struct InvertSideMode { + static SideMode const mode = SideMode::kRight; +}; + +/// Invert SideMode right to left +template <> +struct InvertSideMode { + static SideMode const mode = SideMode::kLeft; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines correct compare operation for Triangular matrix boundary +template +struct TrMatrixCompareOp { + using Index = int32_t; + using Type = typename platform::conditional< + (kFillMode == FillMode::kLower), + greater_equal, + less_equal>::type; +}; + +template +struct TrMatrixCompareOp { + using Index = int32_t; + using Type = typename platform::conditional< + (kFillMode == FillMode::kLower), + greater_equal, + less_equal>::type; +}; + +template +struct TrMatrixCompareOp { + using Index = int32_t; + using Type = typename platform::conditional< + (kFillMode == FillMode::kLower), + greater, + less>::type; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Returns precision in terms of bits (based on datatype) to fill tensors with. +// Defaults to 5 bits of mantissa for TF32 and FP32 (with implicit round-offs). +// Also defines acceptable mantissa result variance/error. +template +struct MantissaInBits { + static int constexpr bits = 5; + static double constexpr error = 1.0e-7; +}; + +// Full precision is supported for FP64 +template <> +struct MantissaInBits { + static int constexpr bits = 30; + static double constexpr error = 1.0e-15; +}; + +template <> +struct MantissaInBits> { + static int constexpr bits = 30; + static double constexpr error = 1.0e-15; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/blas3_types.h b/server/punica_kernels/include/cutlass/cutlass/blas3_types.h new file mode 100644 index 00000000..f9e31846 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/blas3_types.h @@ -0,0 +1,78 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+#pragma once
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Enumerated type describing the type of kernel (based on input or output matrices).
+enum class BlasMode {
+  kGemm,
+  kSymmetric,
+  kHermitian,
+  kTriangular,
+  kInvalid
+};
+
+/// Enumerated type describing the fill mode for matrices for BLAS functions.
+enum class FillMode {
+  kFull,      /// The entire tensor is covered.
+  kLower,     /// The 'lower' part of a tensor is covered including the diagonal.
+  kUpper,     /// The 'upper' part of a tensor is covered including the diagonal.
+  kDiagonal,  /// Only diagonal elements are covered.
+  kNone,      /// No element is covered.
+  kInvalid
+};
+
+/// Enumerated type describing the diagonal property of matrices for BLAS functions.
+enum class DiagType {
+  kNonUnit,
+  kUnit,
+  kZero,      // Only used internally for computing SYMM/HEMM
+  kInvalid
+};
+
+/// Enumerated type describing which side of the matrix equation the dense matrix is on for BLAS functions.
+enum class SideMode {
+  kLeft,
+  kRight,
+  kInvalid
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/server/punica_kernels/include/cutlass/cutlass/block_striped.h b/server/punica_kernels/include/cutlass/cutlass/block_striped.h
new file mode 100644
index 00000000..aaf3a243
--- /dev/null
+++ b/server/punica_kernels/include/cutlass/cutlass/block_striped.h
@@ -0,0 +1,266 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for performing block-striped access (load, store, reduce) of trivially-copyable, + statically-sized array types to global memory. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/wmma_array.h" +#include "cutlass/functional.h" +#include "cutlass/complex.h" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// AccessWidth +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes the maximal power-of-two that evenly divides the size of T, capped at Limit +template < + typename T, + int Limit> +struct AccessWidth +{ + // Inductive case + template < + int ObjectBytes, /// Size of T in bytes + int AlignBytes, /// Template induction variable + bool IsAligned = /// Whether ObjectBytes is an even multiple of AlignBytes + ((AlignBytes <= Limit) && (ObjectBytes % AlignBytes == 0))> + struct Detail + { + static const int value = Detail::value; + }; + + // Base case (ObjectBytes is not an even multiple of AlignBytes) + template < + int ObjectBytes, /// Size of T in bytes + int AlignBytes> /// Template induction variable + struct Detail + { + static const int value = AlignBytes / 2; + }; + + /// The maximal power-of-two that evenly divides the size of T + static const int value = Detail< + (int) sizeof(T), + 1>::value; +}; + + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// StripedAccessType +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// ReinterpretCast type for striping a trivially-copyable type in global memory +/// (Default specialization. Striping granularity is type T.) 
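+///
+/// Editor's sketch (not part of the original header) of what this family of aliases is expected
+/// to resolve to: the access type packs as many elements as fit in the widest power-of-two
+/// transfer (capped at 16 bytes) that evenly divides the fragment size.
+///
+/// @code
+/// using Fragment = cutlass::Array<float, 8>;               // 32 bytes
+/// using Access   = cutlass::StripedAccessType<Fragment>;   // ~ AlignedArray<float, 4, 16>
+/// static_assert(sizeof(Access) == 16, "one 128-bit transaction per striped access (assumed)");
+/// @endcode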
+template < + typename T, /// Data type + int TransferBytes = /// Data access width (16 byte max for global memory access on current architectures) + AccessWidth::value> +struct alignas(TransferBytes) StripedAccessType : public T +{}; + + +/// ReinterpretCast type for striping a trivially-copyable type in global memory +/// (Specialization for cutlass::Array. Striping granularity is a multiple of T.) +template < + typename T, /// Array element type + int N, /// Number of elements in array + bool RegisterSized, /// T is register-sized + int TransferBytes> /// Data access width +struct StripedAccessType< + Array, + TransferBytes> +: public AlignedArray< + T, // Element type of StripedAccessType + __NV_STD_MAX(1, TransferBytes / (int) sizeof(T)), // Number of elements T in StripedAccessType + TransferBytes> // Alignment of StripedAccessType +{}; + + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + +/// ReinterpretCast type for striping a trivially-copyable type in global memory +/// (Specialization for cutlass::WmmaFragmentArray. Striping granularity is a multiple of T.) +template< + typename Use, + int m, + int n, + int k, + typename ElementT, + typename Layout, + int kFragments, + int TransferBytes> +struct StripedAccessType< + WmmaFragmentArray, kFragments>, + TransferBytes> +: public AlignedArray< + ElementT, + __NV_STD_MAX(1, TransferBytes / (int) sizeof(ElementT)), + TransferBytes> +{}; + +#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// BlockStriped +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Utility for performing block-striped access (load, store) of trivially-copyable, +/// statically-sized array types to global memory +template < + int BlockThreads, + typename ArrayT, + typename AccessT = StripedAccessType > +struct BlockStriped +{ + /// Number of striped accesses + static const int kStripes = int(sizeof(ArrayT) / sizeof(AccessT)); + static_assert(kStripes > 0, "AccessT type must be smaller than or equal to ArrayT type"); + + /// Load + CUTLASS_DEVICE + static void load(ArrayT &data, ArrayT *ptr, int thread_idx) + { + AccessT *access_input = reinterpret_cast(ptr); + AccessT *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStripes; ++i) { + access_data[i] = access_input[(BlockThreads * i) + thread_idx]; + } + } + + /// Load & Add + CUTLASS_DEVICE + static void load_add(ArrayT &data, ArrayT *ptr, int thread_idx) + { + AccessT *access_input = reinterpret_cast(ptr); + AccessT *access_data = reinterpret_cast(&data); + + plus add; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStripes; ++i) + { + access_data[i] = add(access_data[i], access_input[(BlockThreads * i) + thread_idx]); + } + } + + /// Store + CUTLASS_DEVICE + static void store(ArrayT *ptr, const ArrayT &data, int thread_idx) + { + AccessT *access_output = reinterpret_cast(ptr); + const AccessT *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStripes; ++i) { + access_output[(BlockThreads * i) + thread_idx] = access_data[i]; + } + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// BlockStripedReduce +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, +/// statically-sized array types 
to global memory. +/// (Default specialization) +template < + int BlockThreads, + typename ArrayT, + typename ElementT = typename StripedAccessType::Element> +struct BlockStripedReduce : + BlockStriped< + BlockThreads, + ArrayT, + ElementT> +{ + /// Reduce + CUTLASS_DEVICE + static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) + { + cutlass::atomic_add reduce; + ElementT *access_output = reinterpret_cast(ptr); + const ElementT *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < BlockStripedReduce::kStripes; ++i) { + reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); + } + } +}; + + +/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, +/// statically-sized array types to global memory. +/// (Specialization for half_t. Uses half2 vectorized-reduction.) +template < + int BlockThreads, + typename ArrayT> +struct BlockStripedReduce : + BlockStriped< + BlockThreads, + ArrayT, + half2> +{ + static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length"); + + /// Reduce + CUTLASS_DEVICE + static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) + { + cutlass::atomic_add reduce; + half2 *access_output = reinterpret_cast(ptr); + const half2 *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < BlockStripedReduce::kStripes; ++i) + { + reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); + } + } +}; + + +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/cluster_launch.hpp b/server/punica_kernels/include/cutlass/cutlass/cluster_launch.hpp new file mode 100644 index 00000000..b5944de9 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/cluster_launch.hpp @@ -0,0 +1,268 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief PTX for TMA Tensor Memory Access operators on memory added for SM90 +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/trace.h" +#if defined(__CUDACC_RTC__) +#include +#else +#include +#include +#endif + +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) +# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED +#endif + +namespace cutlass { + +#ifndef NDEBUG +#define Return_Status(cudaError_t_status) \ + if (cudaError_t_status != cudaSuccess) { \ + fprintf(stderr, \ + "[ ERROR: CUDA Runtime ] %s:%d: %s\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(cudaError_t_status)); \ + return Status::kInvalid; \ + } else { \ + return Status::kSuccess; \ + } +#else +#define Return_Status(cudaError_t_status) \ + if (cudaError_t_status != cudaSuccess) { \ + return Status::kInvalid; \ + } else { \ + return Status::kSuccess; \ + } +#endif + +struct ClusterLauncher { + constexpr static int MaxClusterSize = 32; + + // Check for hardware compatibility + static inline CUTLASS_HOST + Status check_cluster_dims(dim3 grid, dim3 cluster) { + if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) && + (grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST("ClusterLauncher: Invalid cluster configuration -- aborting launch."); + return Status::kInvalid; + } + } + + static inline CUTLASS_HOST + Status +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + init(void const* kernel_function) +#else + init(void const* /* kernel_function */) +#endif + { +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + if (kernel_function == nullptr) { + CUTLASS_TRACE_HOST("kernel_function is null"); + return Status::kInvalid; + } + CUTLASS_TRACE_HOST("Checking previous error state before calling cudaFuncSetAttribute"); + cudaError_t prevStatus = cudaGetLastError(); + if (prevStatus != cudaSuccess) { + fprintf(stderr, + "[ ERROR: CUDA Runtime ] %s:%d: %s\n", + __FILE__, + __LINE__, + cudaGetErrorString(prevStatus)); + return Status::kInvalid; + } + CUTLASS_TRACE_HOST("Calling cudaFuncSetAttribute"); +#endif + // This attribute was added in CUDA 11.8. 
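+    // Editor's note: setting cudaFuncAttributeNonPortableClusterSizeAllowed = 1 opts the kernel
+    // in to cluster shapes larger than the portable limit of 8 thread blocks per cluster; whether
+    // a given non-portable cluster size actually launches still depends on the device. Cluster
+    // launch requires an SM90-class GPU at run time and the CUDA 11.8+ toolkit at compile time
+    // (the latter is what the CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED guard above checks).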
+ cudaError_t status = + cudaFuncSetAttribute( + kernel_function, cudaFuncAttributeNonPortableClusterSizeAllowed, 1); + Return_Status(status); +#else + return Status::kInvalid; +#endif + } + + // This is the method we expect to use going forward + static inline CUTLASS_HOST + Status launch( + dim3 const grid_dims, + dim3 const cluster_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void const* kernel, + void** kernel_params) { +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + if (check_cluster_dims(grid_dims, cluster_dims) != Status::kSuccess) { + CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting."); + return Status::kInvalid; + } + + auto init_status = init(kernel); + if (init_status != Status::kSuccess) { + CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting."); + return Status::kInvalid; + } + + cudaLaunchConfig_t launch_config; + launch_config.gridDim = {grid_dims.x, grid_dims.y, grid_dims.z}; + launch_config.blockDim = {block_dims.x, block_dims.y, block_dims.z}; + launch_config.dynamicSmemBytes = smem_size; + launch_config.stream = cuda_stream; + + cudaLaunchAttribute launch_attribute[1]; + launch_attribute[0].id = cudaLaunchAttributeClusterDimension; + launch_attribute[0].val.clusterDim.x = cluster_dims.x; + launch_attribute[0].val.clusterDim.y = cluster_dims.y; + launch_attribute[0].val.clusterDim.z = cluster_dims.z; + + launch_config.attrs = launch_attribute; + launch_config.numAttrs = 1; + + CUTLASS_TRACE_HOST("ClusterLauncher: Launching GPC_CLUSTER_GRID GridDims = " + "(" << grid_dims.x << ", " << grid_dims.y << ", " << grid_dims.z << "), " + "And ClusterDims = " + "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); + + cudaError_t status = cudaLaunchKernelExC(&launch_config, kernel, kernel_params); + Return_Status(status); +#else + CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch."); + return Status::kInvalid; +#endif + } + +}; + +namespace detail { + +template +void* checked_addressof(Arg&& arg) { + static_assert(! std::is_rvalue_reference_v || ! std::is_const_v, "You cannot take the address of a const rvalue reference (const T&&)."); + // We use std::addressof to ensure we get the address, + // in case the type has an overloaded operator&. + // Note that this precludes `const T&&` references. + return const_cast(reinterpret_cast(std::addressof(arg))); +} + +} // namespace detail + +//! Parameters for launch_on_cluster (see below). +struct ClusterLaunchParams { + //! Grid dimensions + dim3 grid_dims{1, 1, 1}; + + //! Block dimensions + dim3 block_dims{1, 1, 1}; + + //! Cluster dimensions + dim3 cluster_dims{1, 1, 1}; + + //! Number of bytes required for the kernel's shared memory. + int smem_size_in_bytes = 0; + + //! CUDA stream on which to launch the kernel. + cudaStream_t cuda_stream = nullptr; +}; + +/// @brief Launch the kernel on the stream using cluster launch. +/// +/// @param params Cluster launch parameters (see above). +/// @param kernel_ptr Pointer to the kernel function (see example). +/// @param args Zero or more arguments to pass to the kernel. +/// +/// @tparam Args Types of the arguments passed to the kernel. +/// Don't specify this/these template argument(s) explicitly. +/// +/// @return Status::Success on success, else an error code. 
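+/// @note (Editor's addition) Cluster launch is only compiled in when a CUDA 11.8+ toolkit is
+///   detected (CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED); otherwise this wrapper is expected to return
+///   Status::kInvalid. At run time it additionally requires an SM90-class device.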
+/// +/// @code +/// template +/// __global__ void kernel(A a, B b, C c); +/// +/// X x = get_x(); +/// Y y = get_y(); +/// Z z = get_z(); +/// +/// void const* kernel_ptr = +/// const_cast(reinterpret_cast( +/// &kernel)); +/// auto status = launch_on_cluster( +/// {grid_dims, block_dims, cluster_dims, sizeof(SharedMemory)}, +/// kernel_ptr, x, y, z); +/// @endcode +template +CUTLASS_HOST cutlass::Status +launch_kernel_on_cluster(const ClusterLaunchParams& params, + void const* kernel_ptr, + Args&& ... args) +{ + // Unfortunately, we find ourselves needing to pass in + // the parameters as an array of raw pointers. + if constexpr (sizeof...(Args) == 0) { + return cutlass::ClusterLauncher::launch( + params.grid_dims, + params.cluster_dims, + params.block_dims, + params.smem_size_in_bytes, + params.cuda_stream, + kernel_ptr, nullptr); + } + else { + void* kernel_params[sizeof...(Args)] = { + detail::checked_addressof(std::forward(args))... + }; + return cutlass::ClusterLauncher::launch( + params.grid_dims, + params.cluster_dims, + params.block_dims, + params.smem_size_in_bytes, + params.cuda_stream, + kernel_ptr, + kernel_params); + } +} + +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/complex.h b/server/punica_kernels/include/cutlass/cutlass/complex.h new file mode 100644 index 00000000..32cfa5f7 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/complex.h @@ -0,0 +1,744 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. 
+ + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. +*/ + +#pragma once + +#include + +#include + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/real.h" + +#include "cutlass/numeric_types.h" + +#include "cutlass/fast_math.h" + +#if !defined(__CUDACC_RTC__) +#include +#endif + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Enumeraed type describing a transformation on a complex value. +enum class ComplexTransform { + kNone, + kConjugate +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines ComplexTransform inversions +template +struct InvertComplexTransform; + +/// Invert ComplexTransform from kNone to kConjugate +template <> +struct InvertComplexTransform { + static ComplexTransform const transform = ComplexTransform::kConjugate; +}; + +/// Invert ComplexTransform from kConjugate to kNone +template <> +struct InvertComplexTransform { + static ComplexTransform const transform = ComplexTransform::kNone; +}; +///////////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Accessors for CUDA complex types +// + +#if !defined(__CUDACC_RTC__) +/// Returns the real part of the complex number +CUTLASS_HOST_DEVICE +float const &real(cuFloatComplex const &z) { return z.x; } + +/// Returns the real part of the complex number +CUTLASS_HOST_DEVICE +float &real(cuFloatComplex &z) { return z.x; } + +/// Returns the real part of the complex number +CUTLASS_HOST_DEVICE +double const &real(cuDoubleComplex const &z) { return z.x; } + +/// Returns the real part of the complex number +CUTLASS_HOST_DEVICE +double &real(cuDoubleComplex &z) { return z.x; } + +/// Returns the imaginary part of the complex number +CUTLASS_HOST_DEVICE +float const &imag(cuFloatComplex const &z) { return z.y; } + +/// Returns the imaginary part of the complex number +CUTLASS_HOST_DEVICE +float &imag(cuFloatComplex &z) { return z.y; } + +/// Returns the imaginary part of the complex number +CUTLASS_HOST_DEVICE +double const &imag(cuDoubleComplex const &z) { return z.y; } + +/// Returns the imaginary part of the complex number +CUTLASS_HOST_DEVICE +double &imag(cuDoubleComplex &z) { return z.y; } +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Class for representing and manipulating complex numbers with conversions from built-in CUDA +/// complex types. 
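+///
+/// Editor's usage sketch (relies only on the constructors, operators, and accessors defined below):
+///
+/// @code
+/// cutlass::complex<float> a(1.0f, 2.0f);
+/// cutlass::complex<float> b(3.0f, -1.0f);
+/// cutlass::complex<float> c = a * b + conj(a);   // (1+2i)(3-i) + (1-2i)
+/// float re = c.real();
+/// float im = c.imag();
+/// @endcode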
+ +template +class complex +{ + public: + /// Type alias for scalar type + using value_type = T; + + private: + // + // Data members + // + + /// Real part + T _real; + + /// Imaginary part + T _imag; + + public: + +// +// Methods +// + + /// Default constructor + complex() = default; + + /// Constructor + CUTLASS_HOST_DEVICE + complex(T r) : _real(r), _imag(T(0)) {} + + /// Constructor + CUTLASS_HOST_DEVICE + complex(T r, T i) : _real(r), _imag(i) {} + + /// Constructor + template + CUTLASS_HOST_DEVICE + complex(complex const &z) : _real(static_cast(z.real())), _imag(static_cast(z.imag())) {} + + + #if !defined(__CUDACC_RTC__) + /// Conversion from cuFloatComplex + CUTLASS_HOST_DEVICE + complex(cuFloatComplex const &z) : _real(static_cast(cuCrealf(z))), _imag(static_cast(cuCimagf(z))) {} + + /// Conversion from cuDoubleComplex + CUTLASS_HOST_DEVICE + complex(cuDoubleComplex const &z) : _real(static_cast(cuCreal(z))), _imag(static_cast(cuCimag(z))) {} + #endif + + /// Equality operator + CUTLASS_HOST_DEVICE bool operator==(complex const &rhs) const { + return this->real() == rhs.real() && this->imag() == rhs.imag(); + } + + /// Inequality operator + CUTLASS_HOST_DEVICE bool operator!=(complex const &rhs) const { + return !(*this == rhs); + } + + /// Addition + template + CUTLASS_HOST_DEVICE complex operator+(complex const &rhs) const { + return complex(this->real() + rhs.real(), this->imag() + rhs.imag()); + } + + /// Reduction into memory address. Components may update out of order. + template + CUTLASS_DEVICE void red(complex *ptr) const { + static_assert(platform::is_same::value, "Component type must match"); + cutlass::atomic_add reduce; + reduce(&ptr->_real, _real); + reduce(&ptr->_imag, _imag); + } + + /// Reduction into memory address. Components may update out of order. 
(Half specialization) + CUTLASS_DEVICE void red(complex *ptr) const { + static_assert(platform::is_same::value, "Component type must match"); + half2 *h2_ptr = reinterpret_cast(ptr); + half2 h2_data = reinterpret_cast(*this); + cutlass::atomic_add reduce; + reduce(h2_ptr, h2_data); + } + + /// Subtraction + template + CUTLASS_HOST_DEVICE complex operator-(complex const &rhs) const { + return complex(this->real() - rhs.real(), this->imag() - rhs.imag()); + } + + /// Multiplication + template + CUTLASS_HOST_DEVICE complex operator*(complex const &rhs) const { + return complex(this->real() * rhs.real() - this->imag() * rhs.imag(), + this->real() * rhs.imag() + this->imag() * rhs.real()); + } + + /// Scalar Multiplication + template + CUTLASS_HOST_DEVICE complex operator*(A const &s) const { + return complex(this->real() * s, this->imag() * s); + } + + /// Division + template + CUTLASS_HOST_DEVICE complex operator/(complex const &rhs) const { + T d = T(rhs.real() * rhs.real() + rhs.imag() * rhs.imag()); + + return complex( + (real() * rhs.real() + imag() * rhs.imag()) / d, + (imag() * rhs.real() - real() * rhs.imag()) / d + ); + } + + /// Scalar Division + template + CUTLASS_HOST_DEVICE complex operator/(A const &s) const { + return complex(this->real() / s, this->imag() / s); + } + + /// Addition + template + CUTLASS_HOST_DEVICE complex &operator+=(complex const &rhs) { + *this = *this + rhs; + return *this; + } + + /// Subtraction + template + CUTLASS_HOST_DEVICE complex &operator-=(complex const &rhs) { + *this = *this - rhs; + return *this; + } + + /// Multiplication + template + CUTLASS_HOST_DEVICE complex &operator*=(complex const &rhs) { + *this = *this * rhs; + return *this; + } + + /// Scalar multiplication + template + CUTLASS_HOST_DEVICE complex &operator*=(A s) { + *this = *this * s; + return *this; + } + + /// Division + template + CUTLASS_HOST_DEVICE complex &operator/=(complex const &rhs) { + *this = *this / rhs; + return *this; + } + + /// Accesses the real part of the complex number + CUTLASS_HOST_DEVICE + T const &real() const { return _real; } + + /// Accesses the real part of the complex number + CUTLASS_HOST_DEVICE + T &real() { return _real; } + + /// Accesses the imaginary part of the complex number + CUTLASS_HOST_DEVICE + T const &imag() const { return _imag; } + + /// Accesses the imaginary part of the complex number + CUTLASS_HOST_DEVICE + T &imag() { return _imag; } + + /// Set the real part of the complex number + CUTLASS_HOST_DEVICE + void real(T real) { _real = real; } + + /// Set the imaginary part of the complex number + CUTLASS_HOST_DEVICE + void imag(T imag) { _imag = imag; } + + #if !defined(__CUDACC_RTC__) + /// Converts to cuFloatComplex + CUTLASS_HOST_DEVICE + explicit operator cuFloatComplex() const { return make_cuFloatComplex(float(real()), float(imag())); } + + /// Converts to cuDoubleComplex + CUTLASS_HOST_DEVICE + explicit operator cuDoubleComplex() const { return make_cuDoubleComplex(real(), imag()); } + #endif +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Accessors for complex template +// + +/// Returns the real part of the complex number +template +CUTLASS_HOST_DEVICE T const &real(complex const &z) { + return z.real(); +} + +/// Returns the real part of the complex number +template +CUTLASS_HOST_DEVICE T &real(complex &z) { + return z.real(); +} + +/// Returns the imaginary part of the complex number +template +CUTLASS_HOST_DEVICE T const &imag(complex const &z) { + return 
z.imag(); +} + +/// Returns the imaginary part of the complex number +template +CUTLASS_HOST_DEVICE T &imag(complex &z) { + return z.imag(); +} + +/// Returns the real part of the real number +template +CUTLASS_HOST_DEVICE T const &real(T const &r) { + return r; +} + +/// Returns the real part of the real number +template +CUTLASS_HOST_DEVICE T &real(T &r) { + return r; +} + +/// Returns the imaginary part of the real number +template +CUTLASS_HOST_DEVICE T const &imag(T const &r) { + return T(); +} + +/// Returns the imaginary part of the complex number +template +CUTLASS_HOST_DEVICE T &imag(T &r) { + return T(); +} + +// +// Output operators +// + +#if !defined(__CUDACC_RTC__) +template +std::ostream &operator<<(std::ostream &out, complex const &z) { + T _r = real(z); + T _i = imag(z); + + if (bool(_i)) { + return out << _r << "+i" << _i; + } + return out << _r; +} +#endif + +// +// Non-member operators defined for complex types +// + + +// +// Non-member functions defined for complex numbers +// + +/// Returns the magnitude of the complex number +template +CUTLASS_HOST_DEVICE T abs(complex const &z) { + return sqrt(norm(z)); +} + +/// Returns the magnitude of the complex number +template +CUTLASS_HOST_DEVICE T arg(complex const &z) { + return atan2(imag(z), real(z)); +} + +/// Returns the squared magnitude of a real number +template +CUTLASS_HOST_DEVICE T norm(T const &z) { + return z * z; +} + +/// Returns the squared magnitude of a real number +template <> +CUTLASS_HOST_DEVICE int8_t norm(int8_t const &z) { + return static_cast(z * z); +} + +/// Returns the squared magnitude of a complex number +template +CUTLASS_HOST_DEVICE double norm(complex const &z) { + return real(z) * real(z) + imag(z) * imag(z); +} + +/// Norm-accumulate calculation +template +CUTLASS_HOST_DEVICE R norm_accumulate(T const &x, R const & accumulator) { + return accumulator + static_cast(x) * static_cast(x); +} + +/// Norm accumulate specialized for complex types +template +CUTLASS_HOST_DEVICE R norm_accumulate(complex const &z, R const &accumulator) { + return accumulator + static_cast(real(z)) * static_cast(real(z)) + + static_cast(imag(z)) * static_cast(imag(z)); +} + +CUTLASS_HOST_DEVICE float conj(float const &z) { + return z; +} + +CUTLASS_HOST_DEVICE double conj(double const &z) { + return z; +} + +CUTLASS_HOST_DEVICE half_t conj(half_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE int32_t conj(int32_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE uint32_t conj(uint32_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE int64_t conj(int64_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE uint64_t conj(uint64_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE int4b_t conj(int4b_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE uint4b_t conj(uint4b_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE bfloat16_t conj(bfloat16_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE uint1b_t conj(uint1b_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE tfloat32_t conj(tfloat32_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE float_e4m3_t conj(float_e4m3_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE float_e5m2_t conj(float_e5m2_t const& z) { + return z; +} + + +/// Returns the complex conjugate +template +CUTLASS_HOST_DEVICE complex conj(complex const &z) { + return complex(real(z), -imag(z)); +} + +/// Projects the complex number z onto the Riemann sphere +template +CUTLASS_HOST_DEVICE complex proj(complex const &z) { + T d = real(z) * real(z) + imag(z) * imag(z) + T(1); + return complex((T(2) 
* real(z)) / d, (T(2) * imag(z)) / d); +} + +/// Returns a complex number with magnitude r and phase theta +template +CUTLASS_HOST_DEVICE complex polar(T const &r, T const &theta = T()) { + return complex(r * cos(theta), r * sin(theta)); +} + +/// Computes the complex exponential of z. +template +CUTLASS_HOST_DEVICE complex exp(complex const &z) { + return complex(fast_exp(real(z)) * fast_cos(imag(z)), fast_exp(real(z)) * fast_sin(imag(z))); +} + +/// Computes the log of z +template +CUTLASS_HOST_DEVICE complex log(complex const &z) { + return complex(log(abs(z)), arg(z)); +} + +/// Computes the log base 10 of z +template +CUTLASS_HOST_DEVICE complex log10(complex const &z) { + return log(z) / T(log(T(10))); +} + +/// Computes the square root of complex number z +template +CUTLASS_HOST_DEVICE complex sqrt(complex const &z) { + return sqrt(T(2)) / T(2) * + complex(sqrt(sqrt(norm(z)) + real(z)), + (imag(z) < 0 ? T(-1) : T(1)) * sqrt(sqrt(norm(z)) - real(z))); +} + +/// Computes the cosine of complex z. +template +CUTLASS_HOST_DEVICE complex cos(complex const &z) { + return (exp(z) + exp(-z)) / T(2); +} + +/// Computes the sin of complex z. +template +CUTLASS_HOST_DEVICE complex sin(complex const &z) { + return (exp(-z) - exp(z)) * complex(T(0), T(1) / T(2)); +} + +/// Comparison +template +CUTLASS_HOST_DEVICE bool operator<(complex const &lhs, complex const &rhs) { + return true; +} + +////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for complex-valued type. +template +struct RealType< complex > +{ + using Type = T; + + /// Number of elements + static int const kExtent = 2; + + CUTLASS_HOST_DEVICE + static complex from_real(double x) { + return complex(static_cast(x)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +CUTLASS_HOST_DEVICE +cutlass::complex from_real >(double r) { + return cutlass::complex(half_t(r)); +} + +template <> +CUTLASS_HOST_DEVICE +cutlass::complex from_real >(double r) { + return cutlass::complex(float(r)); +} + +template <> +CUTLASS_HOST_DEVICE +cutlass::complex from_real >(double r) { + return cutlass::complex(r); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct is_complex { + static bool const value = false; +}; + +template +struct is_complex> { + static bool const value = true; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Squares with optional conversion +template +struct magnitude_squared, Output> { + CUTLASS_HOST_DEVICE + Output operator()(complex lhs) const { + multiplies mul_op; + + Output y_r = Output(lhs.real()); + Output y_i = Output(lhs.imag()); + + return mul_op(y_r, y_r) + mul_op(y_i, y_i); + } +}; + +/// Fused multiply-add +template +struct multiply_add, complex, complex> { + CUTLASS_HOST_DEVICE + complex operator()( + complex const &a, + complex const &b, + complex const &c) const { + + T real = c.real(); + T imag = c.imag(); + + real += a.real() * b.real(); + real += -a.imag() * b.imag(); + imag += a.real() * b.imag(); + imag += a.imag () * b.real(); + + return complex{ + real, + imag + }; + } +}; + +/// Fused multiply-add +template +struct multiply_add, T, complex> { + CUTLASS_HOST_DEVICE + complex operator()( + 
complex const &a, + T const &b, + complex const &c) const { + + T real = c.real(); + T imag = c.imag(); + + real += a.real() * b; + imag += a.imag () * b; + + return complex{ + real, + imag + }; + } +}; + +/// Fused multiply-add +template +struct multiply_add, complex> { + CUTLASS_HOST_DEVICE + complex operator()( + T const &a, + complex const &b, + complex const &c) const { + + T real = c.real(); + T imag = c.imag(); + + real += a * b.real(); + imag += a * b.imag(); + + return complex{ + real, + imag + }; + } +}; + +/// Conjugate +template +struct conjugate> { + CUTLASS_HOST_DEVICE + complex operator()(complex const &a) const { + return conj(a); + } +}; + +/// Computes the square of a difference with optional conversion +template +struct magnitude_squared_difference, Output> { + CUTLASS_HOST_DEVICE + Output operator()(complex lhs, complex rhs) const { + multiplies mul_op; + + Output y_r = Output(lhs.real()) - Output(rhs.real()); + Output y_i = Output(lhs.imag()) - Output(rhs.imag()); + + return mul_op(y_r, y_r) + mul_op(y_i, y_i); + } +}; + +/// Reduces value into the data pointed to by ptr (complex specialization) +template +struct atomic_add> { + CUTLASS_DEVICE + void operator()(complex *ptr, const complex &data) + { + data.red(ptr); + } +}; + + +////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/constants.h b/server/punica_kernels/include/cutlass/cutlass/constants.h new file mode 100644 index 00000000..ded66ba2 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/constants.h @@ -0,0 +1,1239 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/* \file + \brief Boost-style constant definitions for floating-point types. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/complex.h" + +/////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace constants { + +/////////////////////////////////////////////////////////////////////////////////// + +// +// Primary templates +// + +/// Returns 1, the multiplicative identity element +template CUTLASS_HOST_DEVICE T one(); + +/// Returns 0, the additive identity element +template CUTLASS_HOST_DEVICE T zero(); + +/// Returns 2 +template CUTLASS_HOST_DEVICE T two(); + +/// Returns pi, approximately 3.141 +template CUTLASS_HOST_DEVICE T pi(); + +/// Returns 2 * pi +template CUTLASS_HOST_DEVICE T two_pi(); + +/// Returns pi / 2 +template CUTLASS_HOST_DEVICE T half_pi(); + +/// Returns sqrt(pi) +template CUTLASS_HOST_DEVICE T root_pi(); + +/// Returns sqrt(pi / 2) +template CUTLASS_HOST_DEVICE T root_half_pi(); + +/// Returns sqrt(2 * pi) +template CUTLASS_HOST_DEVICE T root_two_pi(); + +/// Returns sqrt(ln(4)) +template CUTLASS_HOST_DEVICE T root_ln_four(); + +/// Returns e, approximately 2.718... +template CUTLASS_HOST_DEVICE T e(); + +/// Returns (1/2) +template CUTLASS_HOST_DEVICE T half(); + +/// Returns sqrt(2), approximately 1.414... +template CUTLASS_HOST_DEVICE T root_two(); + +/// Returns sqrt(2)/2, approximately 0.707... +template CUTLASS_HOST_DEVICE T half_root_two(); + +/// Returns ln(2), approximately 0.693... +template CUTLASS_HOST_DEVICE T ln_two(); + +/// Returns ln(ln(2)), approximately -0.3665... +template CUTLASS_HOST_DEVICE T ln_ln_two(); + +/// Returns 1/3, approximately 0.333... +template CUTLASS_HOST_DEVICE T third(); + +/// Returns 2/3, approximately 0.666... +template CUTLASS_HOST_DEVICE T twothirds(); + +/// Returns pi - 3, approximately 0.1416... +template CUTLASS_HOST_DEVICE T pi_minus_three(); + +/// Returns 4 - pi, approximately 0.858... 
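+///
+/// (Editor's note) These primary templates are specialized per type (double, float, tfloat32_t,
+/// and others) further down in this header; a typical use looks like:
+/// @code
+/// float  pi_f = cutlass::constants::pi<float>();         // ~3.14159f
+/// double rt2  = cutlass::constants::root_two<double>();  // ~1.41421
+/// @endcode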
+template CUTLASS_HOST_DEVICE T four_minus_pi(); + + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for double + +/// Returns 1, the multiplicative identity element (specialization for double) +template <> CUTLASS_HOST_DEVICE double one() { + uint64_t bits = 0x3ff0000000000000ull; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), double()); +} + +/// Returns 0, the additive identity element (specialization for double) +template <> CUTLASS_HOST_DEVICE double zero() { + uint64_t bits = 0x0ull; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), double()); +} + +/// Returns 2 (specialization for double) +template <> CUTLASS_HOST_DEVICE double two() { + uint64_t bits = 0x4000000000000000ull; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), double()); +} + +/// Returns pi, approximately 3.141 (specialization for double) +template <> CUTLASS_HOST_DEVICE double pi() { + uint64_t bits = 0x400921fb54442d18ull; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), double()); +} + +/// Returns 2 * pi (specialization for double) +template <> CUTLASS_HOST_DEVICE double two_pi() { + uint64_t bits = 0x401921fb54442d18ull; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), double()); +} + +/// Returns pi / 2 (specialization for double) +template <> CUTLASS_HOST_DEVICE double half_pi() { + uint64_t bits = 0x3ff921fb54442d18ull; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), double()); +} + +/// Returns sqrt(pi) (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_pi() { + uint64_t bits = 0x3ffc5bf891b4ef6aull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), double()); +} + +/// Returns sqrt(pi / 2) (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_half_pi() { + uint64_t bits = 0x3ff40d931ff62705ull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), double()); +} + +/// Returns sqrt(2 * pi) (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_two_pi() { + uint64_t bits = 0x40040d931ff62705ull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), double()); +} + +/// Returns sqrt(ln(4)) (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_ln_four() { + uint64_t bits = 0x3ff2d6abe44afc43ull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), double()); +} + +/// Returns e, approximately 2.718... (specialization for double) +template <> CUTLASS_HOST_DEVICE double e() { + uint64_t bits = 0x4005bf0a8b145769ull; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), double()); +} + +/// Returns (1/2) (specialization for double) +template <> CUTLASS_HOST_DEVICE double half() { + uint64_t bits = 0x3fe0000000000000ull; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), double()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_two() { + uint64_t bits = 0x3ff6a09e667f3bcdull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), double()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for double) +template <> CUTLASS_HOST_DEVICE double half_root_two() { + uint64_t bits = 0x3fe6a09e667f3bcdull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), double()); +} + +/// Returns ln(2), approximately 0.693... (specialization for double) +template <> CUTLASS_HOST_DEVICE double ln_two() { + uint64_t bits = 0x3fe62e42fefa39efull; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), double()); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for double) +template <> CUTLASS_HOST_DEVICE double ln_ln_two() { + uint64_t bits = 0xbfd774f29bdd6b9full; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), double()); +} + +/// Returns 1/3, approximately 0.333... (specialization for double) +template <> CUTLASS_HOST_DEVICE double third() { + uint64_t bits = 0x3fd5555555555555ull; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), double()); +} + +/// Returns 2/3, approximately 0.666... (specialization for double) +template <> CUTLASS_HOST_DEVICE double twothirds() { + uint64_t bits = 0x3fe5555555555555ull; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), double()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for double) +template <> CUTLASS_HOST_DEVICE double pi_minus_three() { + uint64_t bits = 0x3fc21fb54442d180ull; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), double()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for double) +template <> CUTLASS_HOST_DEVICE double four_minus_pi() { + uint64_t bits = 0x3feb7812aeef4ba0ull; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), double()); +} + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for float + +/// Returns 1, the multiplicative identity element (specialization for float) +template <> CUTLASS_HOST_DEVICE float one() { + uint32_t bits = 0x3f800000u; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), float()); +} + +/// Returns 0, the additive identity element (specialization for float) +template <> CUTLASS_HOST_DEVICE float zero() { + uint32_t bits = 0x0u; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), float()); +} + +/// Returns 2 (specialization for float) +template <> CUTLASS_HOST_DEVICE float two() { + uint32_t bits = 0x40000000u; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), float()); +} + +/// Returns pi, approximately 3.141 (specialization for float) +template <> CUTLASS_HOST_DEVICE float pi() { + uint32_t bits = 0x40490fdbu; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), float()); +} + +/// Returns 2 * pi (specialization for float) +template <> CUTLASS_HOST_DEVICE float two_pi() { + uint32_t bits = 0x40c90fdbu; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), float()); +} + +/// Returns pi / 2 (specialization for float) +template <> CUTLASS_HOST_DEVICE float half_pi() { + uint32_t bits = 0x3fc90fdbu; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), float()); +} + +/// Returns sqrt(pi) (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_pi() { + uint32_t bits = 0x3fe2dfc5u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), float()); +} + +/// Returns sqrt(pi / 2) (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_half_pi() { + uint32_t bits = 0x3fa06c99u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), float()); +} + +/// Returns sqrt(2 * pi) (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_two_pi() { + uint32_t bits = 0x40206c99u; + return reinterpret_cast(bits); +} + +/// 
Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), float()); +} + +/// Returns sqrt(ln(4)) (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_ln_four() { + uint32_t bits = 0x3f96b55fu; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), float()); +} + +/// Returns e, approximately 2.718... (specialization for float) +template <> CUTLASS_HOST_DEVICE float e() { + uint32_t bits = 0x402df854u; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), float()); +} + +/// Returns (1/2) (specialization for float) +template <> CUTLASS_HOST_DEVICE float half() { + uint32_t bits = 0x3f000000u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), float()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_two() { + uint32_t bits = 0x3fb504f3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), float()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for float) +template <> CUTLASS_HOST_DEVICE float half_root_two() { + uint32_t bits = 0x3f3504f3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), float()); +} + +/// Returns ln(2), approximately 0.693... (specialization for float) +template <> CUTLASS_HOST_DEVICE float ln_two() { + uint32_t bits = 0x3f317218u; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), float()); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for float) +template <> CUTLASS_HOST_DEVICE float ln_ln_two() { + uint32_t bits = 0xbebba795u; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), float()); +} + +/// Returns 1/3, approximately 0.333... (specialization for float) +template <> CUTLASS_HOST_DEVICE float third() { + uint32_t bits = 0x3eaaaaabu; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), float()); +} + +/// Returns 2/3, approximately 0.666... (specialization for float) +template <> CUTLASS_HOST_DEVICE float twothirds() { + uint32_t bits = 0x3f2aaaabu; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), float()); +} + +/// Returns pi - 3, approximately 0.1416... 
(specialization for float) +template <> CUTLASS_HOST_DEVICE float pi_minus_three() { + uint32_t bits = 0x3e10fdaau; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), float()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for float) +template <> CUTLASS_HOST_DEVICE float four_minus_pi() { + uint32_t bits = 0x3f5bc095u; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), float()); +} + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for tfloat32_t + +/// Returns 1, the multiplicative identity element (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t one() { + uint32_t bits = 0x3f801000u; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), tfloat32_t()); +} + +/// Returns 0, the additive identity element (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t zero() { + uint32_t bits = 0x1000u; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), tfloat32_t()); +} + +/// Returns 2 (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t two() { + uint32_t bits = 0x40001000u; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), tfloat32_t()); +} + +/// Returns pi, approximately 3.141 (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t pi() { + uint32_t bits = 0x40491fdbu; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), tfloat32_t()); +} + +/// Returns 2 * pi (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t two_pi() { + uint32_t bits = 0x40c91fdbu; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), tfloat32_t()); +} + +/// Returns pi / 2 (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t half_pi() { + uint32_t bits = 0x3fc91fdbu; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), tfloat32_t()); +} + +/// Returns sqrt(pi) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_pi() { + uint32_t bits = 0x3fe2efc5u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), tfloat32_t()); +} + +/// Returns sqrt(pi / 2) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_half_pi() { + uint32_t bits = 0x3fa07c99u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for 
complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), tfloat32_t()); +} + +/// Returns sqrt(2 * pi) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_two_pi() { + uint32_t bits = 0x40207c99u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), tfloat32_t()); +} + +/// Returns sqrt(ln(4)) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_ln_four() { + uint32_t bits = 0x3f96c55fu; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), tfloat32_t()); +} + +/// Returns e, approximately 2.718... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t e() { + uint32_t bits = 0x402e0854u; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), tfloat32_t()); +} + +/// Returns (1/2) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t half() { + uint32_t bits = 0x3f001000u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), tfloat32_t()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_two() { + uint32_t bits = 0x3fb514f3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), tfloat32_t()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t half_root_two() { + uint32_t bits = 0x3f3514f3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), tfloat32_t()); +} + +/// Returns ln(2), approximately 0.693... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t ln_two() { + uint32_t bits = 0x3f318218u; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), tfloat32_t()); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t ln_ln_two() { + uint32_t bits = 0xbebbb795u; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), tfloat32_t()); +} + +/// Returns 1/3, approximately 0.333... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t third() { + uint32_t bits = 0x3eaabaabu; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), tfloat32_t()); +} + +/// Returns 2/3, approximately 0.666... 
(specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t twothirds() { + uint32_t bits = 0x3f2abaabu; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), tfloat32_t()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t pi_minus_three() { + uint32_t bits = 0x3e110daau; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), tfloat32_t()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t four_minus_pi() { + uint32_t bits = 0x3f5bd095u; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), tfloat32_t()); +} + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for half_t + +/// Returns 1, the multiplicative identity element (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t one() { + uint16_t bits = 0x3c00u; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), half_t()); +} + +/// Returns 0, the additive identity element (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t zero() { + uint16_t bits = 0x0u; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), half_t()); +} + +/// Returns 2 (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t two() { + uint16_t bits = 0x4000u; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), half_t()); +} + +/// Returns pi, approximately 3.141 (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t pi() { + uint16_t bits = 0x4248u; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), half_t()); +} + +/// Returns 2 * pi (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t two_pi() { + uint16_t bits = 0x4648u; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), half_t()); +} + +/// Returns pi / 2 (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t half_pi() { + uint16_t bits = 0x3e48u; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), half_t()); +} + +/// Returns sqrt(pi) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_pi() { + uint16_t bits = 0x3f17u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE 
complex root_pi< complex >() { + return complex(root_pi(), half_t()); +} + +/// Returns sqrt(pi / 2) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_half_pi() { + uint16_t bits = 0x3d03u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), half_t()); +} + +/// Returns sqrt(2 * pi) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_two_pi() { + uint16_t bits = 0x4103u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), half_t()); +} + +/// Returns sqrt(ln(4)) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_ln_four() { + uint16_t bits = 0x3cb6u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), half_t()); +} + +/// Returns e, approximately 2.718... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t e() { + uint16_t bits = 0x4170u; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), half_t()); +} + +/// Returns (1/2) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t half() { + uint16_t bits = 0x3800u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), half_t()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_two() { + uint16_t bits = 0x3da8u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), half_t()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t half_root_two() { + uint16_t bits = 0x39a8u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), half_t()); +} + +/// Returns ln(2), approximately 0.693... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t ln_two() { + uint16_t bits = 0x398cu; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), half_t()); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t ln_ln_two() { + uint16_t bits = 0xb5ddu; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), half_t()); +} + +/// Returns 1/3, approximately 0.333... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t third() { + uint16_t bits = 0x3555u; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), half_t()); +} + +/// Returns 2/3, approximately 0.666... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t twothirds() { + uint16_t bits = 0x3955u; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), half_t()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t pi_minus_three() { + uint16_t bits = 0x3088u; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), half_t()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t four_minus_pi() { + uint16_t bits = 0x3adeu; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), half_t()); +} + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for bfloat16_t + +/// Returns 1, the multiplicative identity element (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t one() { + uint16_t bits = 0x3f80u; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), bfloat16_t()); +} + +/// Returns 0, the additive identity element (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t zero() { + uint16_t bits = 0x0u; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), bfloat16_t()); +} + +/// Returns 2 (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t two() { + uint16_t bits = 0x4000u; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), bfloat16_t()); +} + +/// Returns pi, approximately 3.141 (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t pi() { + uint16_t bits = 0x4049u; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), bfloat16_t()); +} + +/// Returns 2 * pi (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t two_pi() { + uint16_t bits = 0x40c9u; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), bfloat16_t()); +} + +/// Returns pi / 2 (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t half_pi() { + uint16_t bits = 0x3fc9u; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), bfloat16_t()); +} + +/// Returns sqrt(pi) (specialization for bfloat16_t) 
+template <> CUTLASS_HOST_DEVICE bfloat16_t root_pi() { + uint16_t bits = 0x3fe3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), bfloat16_t()); +} + +/// Returns sqrt(pi / 2) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_half_pi() { + uint16_t bits = 0x3fa0u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), bfloat16_t()); +} + +/// Returns sqrt(2 * pi) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_two_pi() { + uint16_t bits = 0x4020u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), bfloat16_t()); +} + +/// Returns sqrt(ln(4)) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_ln_four() { + uint16_t bits = 0x3f97u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), bfloat16_t()); +} + +/// Returns e, approximately 2.718... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t e() { + uint16_t bits = 0x402eu; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), bfloat16_t()); +} + +/// Returns (1/2) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t half() { + uint16_t bits = 0x3f00u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), bfloat16_t()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_two() { + uint16_t bits = 0x3fb5u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), bfloat16_t()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t half_root_two() { + uint16_t bits = 0x3f35u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), bfloat16_t()); +} + +/// Returns ln(2), approximately 0.693... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t ln_two() { + uint16_t bits = 0x3f31u; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), bfloat16_t()); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t ln_ln_two() { + uint16_t bits = 0xbebcu; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), bfloat16_t()); +} + +/// Returns 1/3, approximately 0.333... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t third() { + uint16_t bits = 0x3eabu; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), bfloat16_t()); +} + +/// Returns 2/3, approximately 0.666... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t twothirds() { + uint16_t bits = 0x3f2bu; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), bfloat16_t()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t pi_minus_three() { + uint16_t bits = 0x3e11u; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), bfloat16_t()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t four_minus_pi() { + uint16_t bits = 0x3f5cu; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), bfloat16_t()); +} +/////////////////////////////////////////////////////////////////////////////////// + +} // namespace constants +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/collective/builders/sm90_common.inl b/server/punica_kernels/include/cutlass/cutlass/conv/collective/builders/sm90_common.inl new file mode 100644 index 00000000..526db83e --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/collective/builders/sm90_common.inl @@ -0,0 +1,96 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/layout/tensor.h" +#include "cutlass/arch/mma.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Maps a rank-1 cute::Shape<> representing the cluster shape on to the IM2COL TMA atom that should be used with it +template +constexpr auto +sm90_cluster_shape_to_im2col_tma_atom(UnimodalClusterShape unimodal_cluster_shape) { + static_assert(cute::rank(unimodal_cluster_shape) == 1, + "Use this function to figure out TMA for each mode individually."); + + if constexpr (cute::size(unimodal_cluster_shape) == 1) { + return cute::SM90_TMA_LOAD_IM2COL{}; + } + else { + return cute::SM90_TMA_LOAD_IM2COL_MULTICAST{}; + } +} + +// Collective tile traits struct that serves as a type list containing a tensor's mem layouts and atoms for the +template< + class GmemTiledCopy_, + class SmemLayout_, + class SmemCopyAtom_ = void +> +struct Sm90ImplicitGemmTileTraits { + using GmemTiledCopy = GmemTiledCopy_; + using SmemLayout = SmemLayout_; + using SmemCopyAtom = SmemCopyAtom_; +}; + +// Accepts a cutlass::layout::Tensor tag and computes the corresponding spatial dimension count +template +constexpr int +gmem_layout_tags_to_spatial_dims() { + static_assert(cute::is_same_v); + if constexpr (cute::is_same_v) { + return 1; + } + else if constexpr (cute::is_same_v) { + return 2; + } + else if constexpr (cute::is_same_v) { + return 3; + } + else { + static_assert(cutlass::detail::dependent_false); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective::detail + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/collective/builders/sm90_gmma_builder.inl b/server/punica_kernels/include/cutlass/cutlass/conv/collective/builders/sm90_gmma_builder.inl new file mode 100644 index 00000000..e5ebcefc --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/collective/builders/sm90_gmma_builder.inl @@ -0,0 +1,257 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/collective/builders/sm90_common.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { +using namespace cute; + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(cute::Int stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. 
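// A rough worked example of the stage-count formula used by the carveout
// overload that follows. Every number below is an assumed illustration value
// (16-bit A/B operands, a 128x128x64 tile, ~228 KiB of shared memory and a
// 32 KiB carveout), not something taken from this file:
constexpr int capacity_bytes = 228 * 1024;   // assumed SM90 smem budget
constexpr int carveout_bytes =  32 * 1024;   // assumed epilogue carveout
constexpr int a_tile_bytes   = 128 * 64 * 2; // MxK tile of 16-bit A
constexpr int b_tile_bytes   = 128 * 64 * 2; // NxK tile of 16-bit B
constexpr int barrier_bytes  = 16;           // rough size of PipelineTmaAsync<1>::SharedStorage
constexpr int stage_bytes    = a_tile_bytes + b_tile_bytes + barrier_bytes;
constexpr int stages         = (capacity_bytes - carveout_bytes) / stage_bytes; // == 6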
+template +constexpr int +compute_stage_count_or_override(StageCountAutoCarveout stage_count) { + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS_FPROP +template < + conv::Operator ConvOp, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ConvOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t || + cute::is_same_v || + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(cutlass::gemm::collective::detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + // For fprop, majorA = K, major B = K; + // For wgrad, majorA = MN, major B = MN; + // For dgrad, majorA = K, major B = MN; + static constexpr cute::GMMA::Major GmmaMajorA = + (ConvOp == conv::Operator::kWgrad) ? cute::GMMA::Major::MN : cute::GMMA::Major::K; + static constexpr cute::GMMA::Major GmmaMajorB = + (ConvOp == conv::Operator::kFprop) ? cute::GMMA::Major::K : cute::GMMA::Major::MN; + + using AtomLayoutMNK = cute::conditional_t, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + // For wgrad kernel, tensor A uses tma tiled mode and tensor B uses tma im2col mode. 
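// A hedged sketch of the operand-A copy-atom selection that the next lines
// perform; it reuses the cluster-shape helpers introduced earlier in this
// patch, and the exact spelling here is illustrative rather than a verbatim
// copy of the file's own definition:
using GmemTiledCopyA_sketch = cute::conditional_t<
    ConvOp == conv::Operator::kWgrad,
    // wgrad: operand A is loaded with a plain tiled TMA atom
    decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(
        cute::shape<1>(ClusterShape_MNK{}))),
    // fprop/dgrad: operand A is loaded with an im2col TMA atom
    decltype(cutlass::conv::collective::detail::sm90_cluster_shape_to_im2col_tma_atom(
        cute::shape<1>(ClusterShape_MNK{})))>;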
+ using GmemTiledCopyA = cute::conditional_t(ClusterShape_MNK{}))), + decltype(cutlass::conv::collective::detail::sm90_cluster_shape_to_im2col_tma_atom(cute::shape<1>(ClusterShape_MNK{})))>; + using GmemTiledCopyB = cute::conditional_t(ClusterShape_MNK{}))), + decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(cute::shape<0>(ClusterShape_MNK{})))>; + + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}), + Step<_2,_1,_3>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}), + Step<_2,_1,_3>{})); + + constexpr static int NumSpatialDimensions = cutlass::conv::collective::detail::gmem_layout_tags_to_spatial_dims(); + + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedImplicitGemm< + ConvOp, PipelineStages, NumSpatialDimensions, ClusterShape_MNK, KernelScheduleType>; + + using CollectiveOp = CollectiveConv< + DispatchPolicy, + TileShape_MNK, + ElementA, + ElementB, + TiledMma, + detail::Sm90ImplicitGemmTileTraits, + detail::Sm90ImplicitGemmTileTraits + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA auto kernel schedule +template < + conv::Operator ConvOp, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ConvOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + +/* +#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1))) + // Cooperative schedule performs best for CUDA Toolkits with version >= 12.1 + + // For TileShape_M == 64, choosing KernelTmaWarpSpecialized as the KernelSchedule + // Since KernelTmaWarpSpecializedCooperative requires TileShape_M to be at least 128 + using KernelWarpSpecializedSchedule = cute::conditional_t(TileShape_MNK{}) == Int<64>{}, + KernelImplicitTmaWarpSpecializedSm90PingPong, KernelImplicitTmaWarpSpecializedSm90Cooperative>; +#else + using KernelWarpSpecializedSchedule = KernelImplicitTmaWarpSpecializedSm90; +#endif +*/ + using KernelWarpSpecializedSchedule = KernelImplicitTmaWarpSpecializedSm90; + + using CollectiveOp = typename CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ConvOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + 
TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelWarpSpecializedSchedule + >::CollectiveOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/collective/collective_builder.hpp b/server/punica_kernels/include/cutlass/cutlass/conv/collective/collective_builder.hpp new file mode 100644 index 00000000..9d6a16c0 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/collective/collective_builder.hpp @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/conv/collective/collective_conv.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Used to specify stage counts or dispatch to automatic computation of stage count +template +struct StageCount { + static constexpr int value = num_stages; + + StageCount() = default; + explicit StageCount(cute::Int) {} +}; + +template +struct StageCountAutoCarveout { + static constexpr int bytes = carveout_bytes; + + StageCountAutoCarveout() = default; + explicit StageCountAutoCarveout(cute::Int) {} +}; + +// Used to automatically let the builder pick the kernel schedule. 
+// Can be overridden with kernel schedule tags in cutlass/conv/dispatch_policy.hpp +struct KernelScheduleAuto {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ArchTag, + class OpClass, + conv::Operator, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType, + class Enable = void +> +struct CollectiveBuilder { + static_assert(cutlass::detail::dependent_false, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "builders/sm90_gmma_builder.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/collective/collective_conv.hpp b/server/punica_kernels/include/cutlass/cutlass/conv/collective/collective_conv.hpp new file mode 100644 index 00000000..d187b5ec --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/collective/collective_conv.hpp @@ -0,0 +1,62 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/conv/collective/detail.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class TileShape, + class ElementA, + class ElementB, + class TiledMma, + class TileTraitsA, + class TileTraitsB +> +struct CollectiveConv { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "sm90_implicit_gemm_gmma_ss_warpspecialized.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/collective/detail.hpp b/server/punica_kernels/include/cutlass/cutlass/conv/collective/detail.hpp new file mode 100644 index 00000000..0f192209 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/collective/detail.hpp @@ -0,0 +1,251 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Construct the stride types for conv collectives based on the dispatch policy, strides 64b by default +template +constexpr auto +sm90_dispatch_policy_to_stride_A() { + if constexpr (DispatchPolicy::ConvOp == conv::Operator::kFprop) { + // Maps to modes ((w,n), C) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, + cute::Int<1>>{}; + } + // Maps to modes ((w,h,n), C) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, + cute::Int<1>>{}; + } + // Maps to modes ((w,h,d,n), C) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, + cute::Int<1>>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kWgrad) { + // Maps to modes (k, nq/npq/nzpq) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1 || + DispatchPolicy::NumSpatialDimensions == 2 || + DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, int64_t>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kDgrad) { + // Maps to modes ((q,n), K) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, + cute::Int<1>>{}; + } + // Maps to modes ((q,p,n), K) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, + cute::Int<1>>{}; + } + // Maps to modes ((q,p,z,n), K) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, + cute::Int<1>>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported ConvOp."); + } +} + +// Construct the stirde types for conv collectives based on the dispatch policy, strides 64b by default +template +constexpr auto +sm90_dispatch_policy_to_stride_B() { + if constexpr (DispatchPolicy::ConvOp == conv::Operator::kFprop) { + // Maps to modes (k, (C,s)) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, int64_t>>{}; + } + // Maps to modes (k, (C,s,r)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, int64_t, int64_t>>{}; + } + // Maps to modes (k, (C,s,r,t)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, int64_t, int64_t, int64_t>>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kWgrad) { + // Maps to modes (C, (w,n)) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, + cute::Stride>{}; + } + // Maps to modes (C, (w,h,n)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, + cute::Stride>{}; + } + // Maps to modes (C, (w,h,d,n)) + else if 
constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, + cute::Stride>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kDgrad) { + // Maps to modes (C, (k,s)) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, cute::Stride>{}; + } + // Maps to modes (C, (k,s,r)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, cute::Stride>{}; + } + // Maps to modes (C, (k,s,r,t)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, cute::Stride>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported ConvOp."); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Compute the lower/near corner, returning it as a cute::array in [W,H,D] order +template +CUTLASS_HOST_DEVICE +constexpr auto +compute_lower_corner_whd(ConvProblemShape const& problem_shape) { + using cute::for_each; + using cute::make_seq; + + cute::array lower{}; + if constexpr (ConvOp == conv::Operator::kFprop || + ConvOp == conv::Operator::kWgrad) { + for_each(make_seq{}, [&](auto i) { + lower[NumSpatialDimensions-1-i] = -1 * problem_shape.lower_padding[i]; + }); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + for_each(make_seq{}, [&](auto i) { + lower[NumSpatialDimensions-1-i] = problem_shape.lower_padding[i] - + (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i]; + }); + } + return lower; +} + +// Computes the upper/far corner, returning it as a cute::array in [W,H,D] order +template +CUTLASS_HOST_DEVICE +constexpr auto +compute_upper_corner_whd(ConvProblemShape const& problem_shape) { + using cute::for_each; + using cute::make_seq; + + cute::array upper{}; + if constexpr (ConvOp == conv::Operator::kFprop) { + for_each(make_seq{}, [&](auto i) { + upper[NumSpatialDimensions-1-i] = problem_shape.upper_padding[i] - + (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i]; + }); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + for_each(make_seq{}, [&](auto i) { + upper[NumSpatialDimensions-1-i] = problem_shape.upper_padding[i] - + (problem_shape.shape_C[i+1] - 1) * problem_shape.dilation[i]; + }); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + for_each(make_seq{}, [&](auto i) { + upper[NumSpatialDimensions-1-i] = problem_shape.lower_padding[i] - + (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i] + problem_shape.shape_C[i+1] - problem_shape.shape_A[i+1]; + }); + } + return upper; +} + +// Compute the lower/near corner of (t,r,s), returning it as a cute::array in [S,R,T] order +template +CUTLASS_HOST_DEVICE +constexpr auto +compute_lower_srt(ConvProblemShape const& problem_shape) { + using cute::for_each; + using cute::make_seq; + + cute::array lower{}; + if constexpr (ConvOp == conv::Operator::kFprop || + ConvOp == conv::Operator::kWgrad) { + for_each(make_seq{}, [&](auto i) { + lower[NumSpatialDimensions-1-i] = 0; + }); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + for_each(make_seq{}, [&](auto i) { + lower[NumSpatialDimensions-1-i] = (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i]; + }); + } + return lower; +} + 
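// A small worked example of the three corner helpers above, using made-up
// problem values (a 2-D fprop convolution with lower_padding = upper_padding
// = {1, 1}, dilation = {1, 1} and a 3x3 filter, i.e. R = S = 3). In [W, H]
// order they evaluate to:
//   lower_corner_whd = { -1, -1 }   // -lower_padding
//   upper_corner_whd = { -1, -1 }   // upper_padding - (S - 1) * dilation = 1 - 2
//   lower_srt        = {  0,  0 }   // fprop/wgrad start the filter window at 0
// These corners are later handed to make_im2col_tma_copy as the near/far
// bounds of the im2col TMA descriptor.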
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective::detail diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/server/punica_kernels/include/cutlass/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp new file mode 100644 index 00000000..f8a02336 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -0,0 +1,616 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor_predicate.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_traits_sm90_im2col.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" + +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/util/packed_stride.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + conv::Operator ConvOp, + int Stages, + int NumSpatialDims, + class ClusterShape, + class KernelSchedule, + int PipelineAsyncMmaStages, + class TileShape_, + class ElementA_, + class ElementB_, + class TiledMma_, + class TileTraitsA_, + class TileTraitsB_> +struct CollectiveConv< + MainloopSm90TmaGmmaWarpSpecializedImplicitGemm< + ConvOp, Stages, NumSpatialDims, ClusterShape, KernelSchedule, PipelineAsyncMmaStages>, + TileShape_, + ElementA_, + ElementB_, + TiledMma_, + TileTraitsA_, + TileTraitsB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedImplicitGemm< + ConvOp, Stages, NumSpatialDims, ClusterShape, KernelSchedule, PipelineAsyncMmaStages>; + using TileShape = TileShape_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = typename TileTraitsA_::GmemTiledCopy; + using GmemTiledCopyB = typename TileTraitsB_::GmemTiledCopy; + using SmemLayoutA = typename TileTraitsA_::SmemLayout; + using SmemLayoutB = typename TileTraitsB_::SmemLayout; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr int NumSpatialDimensions = DispatchPolicy::NumSpatialDimensions; + static constexpr int NumTensorDimensions = NumSpatialDimensions + 2; + // Deduce the kernel-facing stride tuple types based on the dispatch policy + // (which is a function of the number of spatial dimensions, the algorithm, etc.) 
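// For intuition, a hedged example of what this stride deduction is expected
// to produce for a 2-D fprop policy (illustrative spelling; see the
// sm90_dispatch_policy_to_stride_A helper added earlier in this patch):
//   activation modes ((w, h, n), C)
//     -> cute::Stride<cute::Stride<int64_t, int64_t, int64_t>, cute::Int<1>>
// i.e. dynamic 64-bit strides for the spatial and batch modes and a static
// unit stride for the channel mode of the channel-last (NHWC) activation.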
+ using StrideA = decltype(detail::sm90_dispatch_policy_to_stride_A()); + using StrideB = decltype(detail::sm90_dispatch_policy_to_stride_B()); + + using MainloopPipeline = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + // TODO: move pipeline mode tiling into the collective setup phase instead + static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); + static_assert((size<0>(TileShape{}) == size<0>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape."); + static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape."); + + static_assert(rank(SmemLayoutB{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); + static_assert((size<1>(TileShape{}) == size<0>(SmemLayoutB{})), "SmemLayout must be compatible with the tile shape."); + static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutB{})), "SmemLayout must be compatible with the tile shape."); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + // The tma load mode of wgrad is tiled for tensor A and im2col for tensor B while the tma load mode of fprop and dgrad + // kernel is im2col for tensor A and tiled for tensor B. + static_assert((ConvOp == conv::Operator::kWgrad + && (cute::is_same_v || cute::is_same_v)) + || (ConvOp != conv::Operator::kWgrad + && (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopyA - invalid SM90 TMA copy atom specified."); + static_assert((ConvOp == conv::Operator::kWgrad + && (cute::is_same_v || cute::is_same_v)) + || (ConvOp != conv::Operator::kWgrad + && (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopyB - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
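+  // A sketch of the mapping described above (illustrative only; roughly the upstream
+  // CUTLASS spelling of the aliases declared below):
+  //
+  //   static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
+  //   using InternalElementA = cute::conditional_t<ConvertF32toTF32A,
+  //                                                tfloat32_t,
+  //                                                uint_bit_t<sizeof_bits_v<ElementA>>>;
+  //
+  // i.e. f32 operands are staged in shared memory as tf32, while every other element type
+  // is copied as a same-width raw bit pattern so TMA performs no conversion.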
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; + static constexpr uint32_t TmaTransactionBytes = + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(InternalElementA)))+ + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(InternalElementB))); + + // Host side kernel arguments + struct Arguments { + using ProblemShape = ConvProblemShape; + ProblemShape problem_shape{}; + ElementA const* ptr_A{nullptr}; + ElementB const* ptr_B{nullptr}; + }; + +private: + // Note that for fprop and dgrad kernel, the tma load mode is im2col for tensor A and tiled for + // tensor B while for wgrad kernel, the tma load mode is tiled for tensor A and im2col for tensor + // B since operand A, B is swapped. + + // Get tma_load_a instantce. + template + static constexpr auto + get_tma_load_a_instance(TensorA const& tensor_a, typename Arguments::ProblemShape const& problem_shape) { + if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { + // compute the upper and lower corners based on the conv padding + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + auto lower_srt = detail::compute_lower_srt(problem_shape); + + // The calculation of gbasis strides for dgrad kernel needs perform negate for dilation values. + cute::array stride_srt{}; + for (int i = 0; i < NumSpatialDimensions; ++i) { + stride_srt[i] = ConvOp == conv::Operator::kDgrad ? + -problem_shape.dilation[NumSpatialDimensions-1-i] : + problem_shape.dilation[NumSpatialDimensions-1-i]; + } + + return make_im2col_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_0{}), + product_each(shape(SmemLayoutA{}(_,_,_0{}))), + size<1>(ClusterShape{}), + shape(lower_corner_whd), + shape(upper_corner_whd), + cute::reverse(shape(problem_shape.lower_padding)), + cute::reverse(shape(problem_shape.upper_padding)), + cute::reverse(shape(problem_shape.traversal_stride)), + shape(lower_srt), + shape(stride_srt)); + } + // TMA tiled mode for tensor A in wgrad kernel. + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_0{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); + } + } + + // Get tma_load_b instantce. + template + static constexpr auto + get_tma_load_b_instance(TensorB const& tensor_b, typename Arguments::ProblemShape const& problem_shape) { + if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { + return make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_0{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); + } + // TMA im2col mode for tensor B in wgrad kernel. 
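+    // Summary of operand roles and TMA load modes (a reading aid, illustrative only; the
+    // tensor-to-operand mapping follows implicit_gemm_tensor_[a|b]_extent()):
+    //
+    //   op    | tensor A             | tensor B
+    //   ------+----------------------+----------------------
+    //   fprop | activation (im2col)  | filter     (tiled)
+    //   dgrad | output     (im2col)  | filter     (tiled)
+    //   wgrad | output     (tiled)   | activation (im2col)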
+ else if constexpr (ConvOp == conv::Operator::kWgrad) { + // compute the upper and lower corners based on the conv padding + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + auto lower_srt = detail::compute_lower_srt(problem_shape); + + return make_im2col_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_0{}), + product_each(shape(SmemLayoutB{}(_,_,_0{}))), + size<0>(ClusterShape{}), + shape(lower_corner_whd), + shape(upper_corner_whd), + cute::reverse(shape(problem_shape.lower_padding)), + cute::reverse(shape(problem_shape.upper_padding)), + cute::reverse(shape(problem_shape.traversal_stride)), + shape(lower_srt), + cute::reverse(shape(problem_shape.dilation))); + } + } + +public: + + // Device side kernel params + struct Params { + using _Submode = decltype(take<0,NumTensorDimensions-1>(typename Arguments::ProblemShape::TensorExtent{})); + using ProblemShape = cute::conditional_t, + Shape<_Submode, int, _Submode>>; + + // Assumption: StrideA is congruent with Problem_MK + // Select TMA load type according to convolution operator. + using TensorShapeA = cute::conditional_t; + + using TensorShapeB = cute::conditional_t; + + using TMA_A = decltype(get_tma_load_a_instance( + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + make_layout(TensorShapeA{}, StrideA{})), + ConvProblemShape{})); + + using TMA_B = decltype(get_tma_load_b_instance( + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + make_layout(TensorShapeB{}, StrideB{})), + ConvProblemShape{})); + + // Members + TMA_A tma_load_a; + TMA_B tma_load_b; + ProblemShape problem_shape; + }; + + // + // Methods + // + + // Lowers the host side user facing arguments to the kernel facing lauch params + static constexpr Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + // from the flat problem shape arrays of ConvProblemShape, create a rank-3 MNK problem shape tuple + // tma desc creation depends on the original untransformed domain. + + // A extents. + auto shape_A_orig = args.problem_shape.get_shape_A(); + // B extents. + auto shape_B_orig = args.problem_shape.get_shape_B(); + + // Fill inferred cute strides from flat stride arrays + auto dA = make_cute_packed_stride(StrideA{}, args.problem_shape.stride_A, ConvOp); + auto dB = make_cute_packed_stride(StrideB{}, args.problem_shape.stride_B, ConvOp); + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA)); + Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB)); + + auto tma_load_a = get_tma_load_a_instance(tensor_a, args.problem_shape); + auto tma_load_b = get_tma_load_b_instance(tensor_b, args.problem_shape); + + auto problem_shape_mnk = args.problem_shape.get_transformed_problem_shape_MNK(); + + return { + tma_load_a, + tma_load_b, + problem_shape_mnk + }; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + Arguments const& args) { + // Activation and Filter channel mode extents much match + bool implementable = true; + // channel mode is major + implementable &= args.problem_shape.stride_A[NumTensorDimensions-1] == 1; + implementable &= args.problem_shape.stride_B[NumTensorDimensions-1] == 1; + + constexpr int tma_alignment_bits = 128; + // A extents. 
+ auto shape_A_orig = args.problem_shape.get_shape_A(); + // B extents. + auto shape_B_orig = args.problem_shape.get_shape_B(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(shape_A_orig, StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(shape_B_orig, StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + return false; + } + + // Check valid padding values for TMA_LOAD_IM2COL + constexpr int padding_limit = (ProblemShape::RankS == 1) ? 65536 : (ProblemShape::RankS == 2 ? 256 : 16); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && problem_shape.lower_padding[i] <= padding_limit && problem_shape.lower_padding[i] >= 0; + implementable = implementable && problem_shape.upper_padding[i] <= padding_limit && problem_shape.upper_padding[i] >= 0; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + + if (problem_shape.groups > 1) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n"); + return false; + } + + return true; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TMA_LOAD_A, + class TensorB, class TMA_LOAD_B, + class KTileIterator + > + CUTLASS_DEVICE void + load(MainloopPipeline pipeline, + PipelineState smem_pipe_producer_state, + TensorA const& gA, TMA_LOAD_A& tma_load_a, + TensorB const& gB, TMA_LOAD_B& tma_load_b, + KTileIterator k_tile_iter, int k_tile_count, + int thad_idx, + TensorStorage& shared_tensors) { + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + if (warp_idx_in_warp_group == 0 and lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + dim3 cluster_local_block_id = cute::block_id_in_cluster(); + auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v || + cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + 
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v || + cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_producer_state for _writing_ + pipeline.producer_acquire(smem_pipe_producer_state); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_producer_state); + + int write_stage = smem_pipe_producer_state.index(); + copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_producer_state + ++smem_pipe_producer_state; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_producer_state) { + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (warp_idx_in_warp_group == 0 and lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_producer_state); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_consumer_state, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState 
smem_pipe_release = smem_pipe_consumer_state; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) { + // WAIT on smem_pipe_consumer_state until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_consumer_state); + + int read_stage = smem_pipe_consumer_state.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_consumer_state; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // WAIT on smem_pipe_consumer_state until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_consumer_state); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_consumer_state.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_producer_state is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_consumer_state and smem_pipe_release + ++smem_pipe_consumer_state; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/conv2d_problem_size.h b/server/punica_kernels/include/cutlass/cutlass/conv/conv2d_problem_size.h new file mode 100644 index 00000000..c66738d0 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/conv2d_problem_size.h @@ -0,0 +1,653 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This file contains definitions and utility functions for describing convolution problem sizes. + + Conv2dProblem desciption: + activation (NHWC), + filter (KRSC), + output (NPQK), + pading (pad_h, pad_w), + stride (stride_h, stride_w), + dilation (dilation_h, dilation_w). + + Free functions to map: + Map tensor extents (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator) + Map tensor sizes (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) + Map tensor problem sizes (Conv2d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) +*/ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm_enumerated_types.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/functional.h" + +namespace cutlass { +namespace conv { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Problem size structure +struct Conv2dProblemSize { + + // Conv2d strictly problem size parameters + int N, H, W, C, P, Q, K, R, S; + int pad_h, pad_w; + int stride_h, stride_w; + int dilation_h, dilation_w; + Mode mode; + + // Conv2d implementation-related parameters + int split_k_slices; + int groups; + + // + // Methods + // + +public: + CUTLASS_HOST_DEVICE + Conv2dProblemSize(): + N(0), H(0), W(0), C(0), P(0), Q(0), K(0), R(0), S(0), + pad_h(0), pad_w(0), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), + mode(Mode::kConvolution), split_k_slices(1), groups(1) { } + + /// Constructor for default padding, stride, dilation, and split-K + CUTLASS_HOST_DEVICE + Conv2dProblemSize( + int N, + int H, + int W, + int C, + int P, + int Q, + int K, + int R, + int S, + Mode mode + ): + N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S), + pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), + mode(mode), split_k_slices(1), groups (1) { } + + /// Constructor + CUTLASS_HOST_DEVICE + Conv2dProblemSize( + int N, + int H, + int W, + int C, + int K, + int R, + int S, + int P, + int Q, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + Mode mode, + int split_k_slices = 1, + int groups = 1 + ): + N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S), + pad_h(pad_h), pad_w(pad_w), stride_h(stride_h), stride_w(stride_w), + dilation_h(dilation_h), dilation_w(dilation_w), + mode(mode), split_k_slices(split_k_slices), groups (groups) { } + + /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord + // set user-defined output size and sets P and Q (include all data members in ctor) + CUTLASS_HOST_DEVICE + Conv2dProblemSize( + cutlass::Tensor4DCoord input_size, // NHWC + cutlass::Tensor4DCoord filter_size, // KRSC + cutlass::Tensor4DCoord padding, // pad_h, _, pad_w, _ + cutlass::MatrixCoord stride, // stride_h, stride_w + cutlass::MatrixCoord dilation, // dilation_h, dilation_w + cutlass::Tensor4DCoord output_size, // NPQK + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, + int split_k_slices = 1, + int groups = 1 + ): + N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), + P(output_size.h()), Q(output_size.w()), + K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), + pad_h(padding[0]), pad_w(padding[2]), + stride_h(stride.row()), stride_w(stride.column()), + dilation_h(dilation.row()), dilation_w(dilation.column()), + mode(mode), split_k_slices(split_k_slices), groups(groups) {} + + /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord + // computes output size and sets P and Q (skip output from ctor arguments) + CUTLASS_HOST_DEVICE + Conv2dProblemSize( + cutlass::Tensor4DCoord input_size, // NHWC + cutlass::Tensor4DCoord filter_size, // KRSC + cutlass::Tensor4DCoord padding, // pad_h, upper_pad_h, pad_w, upper_pad_w + cutlass::MatrixCoord stride, // stride_h, stride_w + cutlass::MatrixCoord dilation, // dilation_h, dilation_w + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, + int 
split_k_slices = 1, + int groups = 1 + ): + N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), + K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), + pad_h(padding[0]), pad_w(padding[2]), + stride_h(stride.row()), stride_w(stride.column()), + dilation_h(dilation.row()), dilation_w(dilation.column()), + mode(mode), split_k_slices(split_k_slices), groups(groups) { + // set output P and Q + P = ((H + pad_h + padding[1] - R * dilation_h) / stride_h) + 1; + Q = ((W + pad_w + padding[3] - S * dilation_w) / stride_w) + 1; + } + + /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord + // set user-defined output size and sets P and Q (skip padding, striding, and dilation) + CUTLASS_HOST_DEVICE + Conv2dProblemSize( + cutlass::Tensor4DCoord input_size, // NHWC + cutlass::Tensor4DCoord filter_size, // KRSC + cutlass::Tensor4DCoord output_size, // NPQK + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, + int split_k_slices = 1, + int groups = 1 + ): + N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), + P(output_size.h()), Q(output_size.w()), + K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), + pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), + dilation_h(1), dilation_w(1), + mode(mode), split_k_slices(split_k_slices), groups(groups) {} + + // Reset covolution mode in the problem + CUTLASS_HOST_DEVICE + Conv2dProblemSize reset_mode(cutlass::conv::Mode mode_) { + Conv2dProblemSize tmp(*this); + tmp.mode = mode_; + return tmp; + } + + // Reset covolution mode in the problem + CUTLASS_HOST_DEVICE + Conv2dProblemSize reset_split_k_slices(int split_k_slices_) { + Conv2dProblemSize tmp(*this); + tmp.split_k_slices = split_k_slices_; + return tmp; + } + + /// Equality operator (ignores mode and split_k_slice) + CUTLASS_HOST_DEVICE + bool operator==(Conv2dProblemSize const &conv) const { + return ( + (N == conv.N) && (H == conv.H) && (W == conv.W) && (C == conv.C) && + (K == conv.K) && (R == conv.R) && (S == conv.S) && + (P == conv.P) && (Q == conv.Q) && + (pad_h == conv.pad_h) && (pad_w == conv.pad_w) && + (stride_h == conv.stride_h) && (stride_w == conv.stride_w) && + (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w) + ); + } + + /// Inequality operator + CUTLASS_HOST_DEVICE + bool operator!=(Conv2dProblemSize const &rhs) const { + return !(*this == rhs); + } + + /// Returns activation extent as Tensor4DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor4DCoord activation_extent() const { + + return cutlass::Tensor4DCoord ({N, H, W, C}); + } + + /// Returns filter extent as Tensor4DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor4DCoord filter_extent() const { + + return cutlass::Tensor4DCoord ({K, R, S, C / groups}); + } + + /// Returns output extent as Tensor4DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor4DCoord output_extent() const { + + return cutlass::Tensor4DCoord ({N, P, Q, K}); + } + + /// Returns activation size in number of elements + CUTLASS_HOST_DEVICE + int64_t activation_size() const { + + return (N * H * W * C); + } + + /// Returns filter size in number of elements + CUTLASS_HOST_DEVICE + int64_t filter_size() const { + + return (K * R * S * C / groups); + } + + /// Returns output size in number of elements + CUTLASS_HOST_DEVICE + int64_t output_size() const { + + return (N * P * Q * K); + } + + /// Returns padding as Tensor4DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor4DCoord padding() const { + + return cutlass::Tensor4DCoord ({pad_h, pad_h, pad_w, pad_w}); + } 
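+  /// Worked example (illustrative only): with H = W = 56, R = S = 3, unit stride and
+  /// dilation, and symmetric padding pad_h = pad_w = 1, the output-computing constructor
+  /// above gives
+  ///   P = ((H + pad_h + upper_pad_h - R * dilation_h) / stride_h) + 1
+  ///     = ((56 + 1 + 1 - 3) / 1) + 1 = 56
+  /// and likewise Q = 56, i.e. the "same size" output implied by the pad = R / 2 default.
+  /// The free function implicit_gemm_problem_size() below then maps this fprop problem to
+  /// a GEMM of extent (N*P*Q, K, R*S*C / groups).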
+ + /// Returns stride as MatrixCoord + CUTLASS_HOST_DEVICE + cutlass::MatrixCoord stride() const { + + return cutlass::MatrixCoord ({stride_h, stride_w}); + } + + /// Returns dilation as MatrixCoord + CUTLASS_HOST_DEVICE + cutlass::MatrixCoord dilation() const { + + return cutlass::MatrixCoord ({dilation_h, dilation_w}); + } + + ///////////////////////////////////////////////////////////////// + // Methods used for strided dgrad implementation + ///////////////////////////////////////////////////////////////// + /// Number of filter r positions to accumulate in gemm-k dim + CUTLASS_HOST_DEVICE + int num_gemm_k_filter_r(int r) const { + return ((R - r + stride_h - 1) / stride_h); + } + + /// Number of filter s positions to accumulate in gemm-k dim + CUTLASS_HOST_DEVICE + int num_gemm_k_filter_s(int s) const { + return ((S - s + stride_w - 1) / stride_w); + } + + /// Number of filter positions to accumulate in gemm-k dim + CUTLASS_HOST_DEVICE + int num_gemm_k_filter_positions(int r, int s) const { + return num_gemm_k_filter_r(r) * num_gemm_k_filter_s(s); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// ImplicitGemm helper functions // +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Determine the problem size of the implicit GEMM operation +CUTLASS_HOST_DEVICE +cutlass::gemm::GemmCoord implicit_gemm_problem_size( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + // Compute problem size + switch (conv_operator) { + case Operator::kFprop: + return gemm::GemmCoord( + problem_size.N * problem_size.P * problem_size.Q, + problem_size.K, + problem_size.R * problem_size.S * problem_size.C / problem_size.groups + ); + case Operator::kDgrad: + return gemm::GemmCoord( + problem_size.N * problem_size.H * problem_size.W, + problem_size.C, + problem_size.R * problem_size.S * problem_size.K + ); + case Operator::kWgrad: + return gemm::GemmCoord( + problem_size.K, + problem_size.R * problem_size.S * problem_size.C, + problem_size.N * problem_size.P * problem_size.Q + ); + default: + break; + } + return gemm::GemmCoord(); +} + +// Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm +CUTLASS_HOST_DEVICE +int implicit_gemm_k_iterations( + Operator conv_operator, + int threadblock_K, + Conv2dProblemSize const &problem_size, + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, + GroupMode group_mode = GroupMode::kNone, + int threadblock_N = 0) { + + int iterations = 0; + + if (group_mode == GroupMode::kNone) { + + if (algorithm == IteratorAlgorithm::kFixedChannels) { + + int positions_per_iteration = threadblock_K / problem_size.C; + switch (conv_operator) { + case Operator::kFprop: + iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / positions_per_iteration; + break; + + default: + break; + } + } + else if (algorithm == IteratorAlgorithm::kFewChannels) { + + switch (conv_operator) { + case Operator::kFprop: + iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K; + break; + + default: + break; + } + } + else { + int elements_per_split_k_slice = 0; + + switch (conv_operator) { + case Operator::kFprop: + elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case 
Operator::kDgrad: + elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kWgrad: + elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; + break; + + default: + break; + } + } + + } else if (group_mode == GroupMode::kDepthwise) { + int channels_per_cta = threadblock_N; + + if (algorithm == IteratorAlgorithm::kAnalytic) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S * + ((channels_per_cta + threadblock_K - 1) / threadblock_K); + break; + + default: + break; + } + } + } else { // Group conv + + int channels_per_group = problem_size.C / problem_size.groups; + int k_per_group = problem_size.K / problem_size.groups; + + if (algorithm == IteratorAlgorithm::kAnalytic) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); + // In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups + if (problem_size.groups != 1) { + if (k_per_group < threadblock_N) { + iterations *= threadblock_N / k_per_group; + } + } + break; + + default: + break; + } + } else if (algorithm == IteratorAlgorithm::kOptimized) { + // Current optimized iterator only support GroupMode::kSingleGroup + if (group_mode == GroupMode::kSingleGroup) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); + break; + + default: + break; + } + } + } + + } + + return iterations; +} + + +template +CUTLASS_HOST_DEVICE +int depthwise_gemm_k_iterations( + Operator conv_operator, + int threadblock_K, + Conv2dProblemSize const &problem_size, + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, + GroupMode group_mode = GroupMode::kNone, + int threadblock_N = 0) { + + int n = problem_size.N; + int p = (problem_size.P + Output_P - 1) / Output_P; + int q = (problem_size.Q + Output_Q - 1) / Output_Q; + + int iterations = (n * p * q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + return iterations; +} + + +CUTLASS_HOST_DEVICE +int implicit_gemm_k_iterations_per_channel( + Operator conv_operator, + Conv2dProblemSize const &problem_size, + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) { + + int iterations = 0; //0 means not applicable + if (algorithm == IteratorAlgorithm::kAnalytic || algorithm == IteratorAlgorithm::kOptimized) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S; + break; + + case Operator::kDgrad: + iterations = problem_size.R * problem_size.S; + break; + + default: + break; + } + } + return iterations; +} + +//////////////////////////////////////////////////////////////////////////////// +// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output) +//////////////////////////////////////////////////////////////////////////////// +/// Returns ImplicitGemm tensor A extent as Tensor4DCoord +CUTLASS_HOST_DEVICE +cutlass::Tensor4DCoord implicit_gemm_tensor_a_extent( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + switch (conv_operator) { + case 
cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); + case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); + case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); + default : break; + } + return cutlass::Tensor4DCoord(); +} + +/// Returns ImplicitGemm tensor B extent as Tensor4DCoord +CUTLASS_HOST_DEVICE +cutlass::Tensor4DCoord implicit_gemm_tensor_b_extent( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); + case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); + case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); + default : break; + } + return cutlass::Tensor4DCoord(); +} + +/// Returns ImplicitGemm tensor C extent as Tensor4DCoord +CUTLASS_HOST_DEVICE +cutlass::Tensor4DCoord implicit_gemm_tensor_c_extent( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); + case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); + case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); + default : break; + } + return cutlass::Tensor4DCoord(); +} + +/// Returns ImplicitGemm tensor A size in number of elements +CUTLASS_HOST_DEVICE +int64_t implicit_gemm_tensor_a_size( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); + case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); + case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); + default : break; + } + return 0; +} + +/// Returns ImplicitGemm tensor B size in number of elements +CUTLASS_HOST_DEVICE +int64_t implicit_gemm_tensor_b_size( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); + case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); + case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); + default : break; + } + return 0; +} + +/// Returns ImplicitGemm tensor C size in number of elements +CUTLASS_HOST_DEVICE +int64_t implicit_gemm_tensor_c_size( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.output_size(); + case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); + case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); + default : break; + } + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Strided dgrad helper functions // +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Returns number of CTAs tile M to cover valid MMAs per starting filter postion +CUTLASS_HOST_DEVICE +int strided_dgrad_tile_m_per_filter( + Conv2dProblemSize const &problem_size, + int tile_size_m) { + + // Compute NHW rows in Dx output that needs MMA per starting filter position + int rows_h_per_filter = (problem_size.H + problem_size.stride_h - 1) / problem_size.stride_h; + int 
rows_w_per_filter = (problem_size.W + problem_size.stride_w - 1) / problem_size.stride_w; + int rows_nhw_per_filter = problem_size.N * rows_h_per_filter * rows_w_per_filter; + + // Number of CTAs tile M to cover valid MMAs per starting filter postion + int tile_m_per_filter = (rows_nhw_per_filter + tile_size_m - 1) / tile_size_m; + + return tile_m_per_filter; +} + +// Computes starting Dx coord (h, w) for given starting filter postion +CUTLASS_HOST_DEVICE +void strided_dgrad_starting_coords( + Conv2dProblemSize const &problem_size, + FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, + int r, int s, + int &start_h, int &start_w) { + + // function locals for remainder by fast divmod + int pad_h_rem_, pad_w_rem_; + + // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h; + stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h); + int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r)); + stride_h_divmod.divmod(start_h, r_); + + //start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w; + stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w); + int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s)); + stride_w_divmod.divmod(start_w, s_); +} + +} // namespace conv +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/conv3d_problem_size.h b/server/punica_kernels/include/cutlass/cutlass/conv/conv3d_problem_size.h new file mode 100644 index 00000000..f77c894e --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/conv3d_problem_size.h @@ -0,0 +1,513 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief This file contains definitions and utility functions for describing convolution problem sizes. + + Conv3dProblem desciption: + activation (NDHWC), + filter (KTRSC), + output (NZPQK), + pading (pad_d, pad_h, pad_w), + stride (stride_d, stride_h, stride_w), + dilation (dilation_d, dilation_h, dilation_w). + + Free functions to map: + Map tensor extents (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator) + Map tensor sizes (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) + Map tensor problem sizes (Conv3d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) +*/ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ + +#pragma once + +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +namespace cutlass { +namespace conv { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Problem size structure +struct Conv3dProblemSize : public Conv2dProblemSize { + // + // Type definitions + // + + // 3D coordinate for padding, stride, and dilation in (d, h, w) dimensions + using Coord3D = Coord<3>; + + // + // Data members + // + + // Conv3d strictly problem size parameters + int D, T, Z; // input depth, filter depth, output depth + int pad_d; // padding in depth dimension + int stride_d; // stride in depth dimension + int dilation_d; // dilation in depth dimension + + // + // Methods + // +public: + CUTLASS_HOST_DEVICE + Conv3dProblemSize(): + Conv2dProblemSize(), + D(0), T(0), Z(0), + pad_d(0), + stride_d(1), + dilation_d(1) { } + + /// Constructor for default padding, stride, dilation, and split-K + CUTLASS_HOST_DEVICE + Conv3dProblemSize( + int N, + int D, + int H, + int W, + int C, + int Z, + int P, + int Q, + int K, + int T, + int R, + int S, + Mode mode + ): + Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode), + D(D), T(T), Z(Z), + pad_d(T / 2), stride_d(1), dilation_d(1) { } + + /// Constructor + CUTLASS_HOST_DEVICE + Conv3dProblemSize( + int N, + int D, + int H, + int W, + int C, + int K, + int T, + int R, + int S, + int Z, + int P, + int Q, + int pad_d, + int pad_h, + int pad_w, + int stride_d, + int stride_h, + int stride_w, + int dilation_d, + int dilation_h, + int dilation_w, + Mode mode, + int split_k_slices = 1, + int groups = 1 + ): + Conv2dProblemSize( + N, H, W, C, K, R, S, P, Q, + pad_h, pad_w, + stride_h, stride_w, + dilation_h, dilation_w, + mode, split_k_slices, groups), + D(D), T(T), Z(Z), + pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d) { } + + /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D + // set *user-defined* output size and sets Z, P, and Q (include all data members in ctor) + CUTLASS_HOST_DEVICE + Conv3dProblemSize( + cutlass::Tensor5DCoord input_size, // NDHWC + cutlass::Tensor5DCoord filter_size, // KTRSC + Coord3D padding, // pad_d, pad_h, pad_w + Coord3D stride, // stride_d, stride_h, stride_w + Coord3D dilation, // dilation_d, dilation_h, dilation_w + cutlass::Tensor5DCoord output_size, // NZPQK + cutlass::conv::Mode mode = 
cutlass::conv::Mode::kCrossCorrelation, + int split_k_slices = 1, + int groups = 1 + ): + Conv2dProblemSize( + {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, + {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, + {padding[1], padding[1], padding[2], padding[2]}, + {stride[1], stride[2]}, + {dilation[1], dilation[2]}, + {output_size.n(), output_size.h(), output_size.w(), output_size.c()}, + mode, split_k_slices, groups), + D(input_size.d()), T(filter_size.d()), Z(output_size.d()), + pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) { } + + /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D + // *computes* output size and sets Z, P and Q (include all data members in ctor) + CUTLASS_HOST_DEVICE + Conv3dProblemSize( + cutlass::Tensor5DCoord input_size, // NDHWC + cutlass::Tensor5DCoord filter_size, // KTRSC + Coord3D padding, // pad_d, pad_h, pad_w + Coord3D stride, // stride_d, stride_h, stride_w + Coord3D dilation, // dilation_d, dilation_h, dilation_w + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, + int split_k_slices = 1, + int groups = 1 + ): + Conv2dProblemSize( + {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, + {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, + {padding[1], padding[1], padding[2], padding[2]}, + {stride[1], stride[2]}, + {dilation[1], dilation[2]}, + mode, split_k_slices, groups), + D(input_size.d()), T(filter_size.d()), + pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) + { + // set output Z + Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1; + } + + /// Constructs convolution problem size from cutlass Tensor5DCoord, Coord3D + // *computes* output size and sets Z, P and Q (include all data members in ctor) + CUTLASS_HOST_DEVICE + Conv3dProblemSize( + cutlass::Tensor5DCoord input_size, // NDHWC + cutlass::Tensor5DCoord filter_size, // KTRSC + CUTLASS_STL_NAMESPACE::tuple padding, // Coord3D {pad_d, pad_h, pad_w} & Coord3D {far pad_d, pad_h, pad_w} to calculate o/p/q + Coord3D stride, // stride_d, stride_h, stride_w + Coord3D dilation, // dilation_d, dilation_h, dilation_w + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, + int split_k_slices = 1, + int groups = 1 + ): + Conv2dProblemSize( + {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, + {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, + {CUTLASS_STL_NAMESPACE::get<0>(padding)[1], CUTLASS_STL_NAMESPACE::get<1>(padding)[1], + CUTLASS_STL_NAMESPACE::get<0>(padding)[2], CUTLASS_STL_NAMESPACE::get<1>(padding)[2]}, + {stride[1], stride[2]}, + {dilation[1], dilation[2]}, + mode, split_k_slices, groups), + D(input_size.d()), T(filter_size.d()), + pad_d(CUTLASS_STL_NAMESPACE::get<0>(padding)[0]), stride_d(stride[0]), dilation_d(dilation[0]) + { + // set output Z + Z = ((D + pad_d + CUTLASS_STL_NAMESPACE::get<1>(padding)[0] - T * dilation_d) / stride_d) + 1; + } + + /// Equality operator (ignores mode and split_k_slice) + CUTLASS_HOST_DEVICE + bool operator==(Conv3dProblemSize const &conv) const { + return ( + (N == conv.N) && (D == conv.D) && (H == conv.H) && (W == conv.W) && (C == conv.C) && + (K == conv.K) && (T == conv.T) && (R == conv.R) && (S == conv.S) && + (Z == conv.Z) &&(P == conv.P) && (Q == conv.Q) && + (pad_d == conv.pad_d) && (pad_h == conv.pad_h) && (pad_w == conv.pad_w) && + (stride_d == conv.stride_d) && (stride_h == conv.stride_h) && (stride_w == conv.stride_w) && + (dilation_d == 
conv.dilation_d) && (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w) + ); + } + + /// Inequality operator + CUTLASS_HOST_DEVICE + bool operator!=(Conv3dProblemSize const &rhs) const { + return !(*this == rhs); + } + + // Reset covolution mode in the problem + CUTLASS_HOST_DEVICE + Conv3dProblemSize reset_mode(cutlass::conv::Mode mode_) { + Conv3dProblemSize tmp(*this); + tmp.mode = mode_; + return tmp; + } + + // Reset covolution mode in the problem + CUTLASS_HOST_DEVICE + Conv3dProblemSize reset_split_k_slices(int split_k_slices_) { + Conv3dProblemSize tmp(*this); + tmp.split_k_slices = split_k_slices_; + return tmp; + } + + /// Returns activation extent as Tensor5DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor5DCoord activation_extent() const { + + return cutlass::Tensor5DCoord ({N, D, H, W, C}); + } + + /// Returns filter extent as Tensor5DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor5DCoord filter_extent() const { + + return cutlass::Tensor5DCoord ({K, T, R, S, C}); + } + + /// Returns output extent as Tensor5DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor5DCoord output_extent() const { + + return cutlass::Tensor5DCoord ({N, Z, P, Q, K}); + } + + /// Returns activation size in number of elements + CUTLASS_HOST_DEVICE + int64_t activation_size() const { + + return (N * D * H * W * C); + } + + /// Returns filter size in number of elements + CUTLASS_HOST_DEVICE + int64_t filter_size() const { + + return (K * T * R * S * C); + } + + /// Returns output size in number of elements + CUTLASS_HOST_DEVICE + int64_t output_size() const { + + return (N * Z * P * Q * K); + } + + /// Returns padding as Coord3D + CUTLASS_HOST_DEVICE + Coord3D padding() const { + + return Coord3D ({pad_d, pad_h, pad_w}); + } + + /// Returns stride as MatrixCoord + CUTLASS_HOST_DEVICE + Coord3D stride() const { + + return Coord3D ({stride_d, stride_h, stride_w}); + } + + /// Returns dilation as MatrixCoord + CUTLASS_HOST_DEVICE + Coord3D dilation() const { + + return Coord3D ({dilation_d, dilation_h, dilation_w}); + } + +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// ImplicitGemm helper functions // +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Determine the problem size of the implicit GEMM operation +CUTLASS_HOST_DEVICE +cutlass::gemm::GemmCoord implicit_gemm_problem_size( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + // Compute problem size + switch (conv_operator) { + case Operator::kFprop: + return gemm::GemmCoord( + problem_size.N * problem_size.Z * problem_size.P * problem_size.Q, + problem_size.K, + problem_size.T * problem_size.R * problem_size.S * problem_size.C + ); + case Operator::kDgrad: + return gemm::GemmCoord( + problem_size.N * problem_size.D * problem_size.H * problem_size.W, + problem_size.C, + problem_size.T * problem_size.R * problem_size.S * problem_size.K + ); + case Operator::kWgrad: + return gemm::GemmCoord( + problem_size.K, + problem_size.T * problem_size.R * problem_size.S * problem_size.C, + problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + ); + default: + break; + } + return gemm::GemmCoord(); +} + +// Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm +CUTLASS_HOST_DEVICE +int implicit_gemm_k_iterations( + Operator conv_operator, + int threadblock_K, + Conv3dProblemSize const &problem_size, + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, + GroupMode 
group_mode = GroupMode::kNone, + int threadblock_N = 0) { + + int iterations = 0; + int elements_per_split_k_slice = 0; + if (group_mode == GroupMode::kNone) { + switch (conv_operator) { + case Operator::kFprop: + elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kDgrad: + elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kWgrad: + elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; + break; + + default: + break; + } + } else if (group_mode == GroupMode::kDepthwise) { + int channels_per_cta = threadblock_N; + + if (algorithm == IteratorAlgorithm::kAnalytic) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.T * problem_size.R * problem_size.S * + ((channels_per_cta + threadblock_K - 1) / threadblock_K); + break; + + default: + break; + } + } + } + + return iterations; +} + +//////////////////////////////////////////////////////////////////////////////// +// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output) +//////////////////////////////////////////////////////////////////////////////// +/// Returns ImplicitGemm tensor A extent as Tensor5DCoord +CUTLASS_HOST_DEVICE +cutlass::Tensor5DCoord implicit_gemm_tensor_a_extent( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); + case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); + case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); + default : break; + } + return cutlass::Tensor5DCoord(); +} + +/// Returns ImplicitGemm tensor B extent as Tensor5DCoord +CUTLASS_HOST_DEVICE +cutlass::Tensor5DCoord implicit_gemm_tensor_b_extent( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); + case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); + case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); + default : break; + } + return cutlass::Tensor5DCoord(); +} + +/// Returns ImplicitGemm tensor C extent as Tensor5DCoord +CUTLASS_HOST_DEVICE +cutlass::Tensor5DCoord implicit_gemm_tensor_c_extent( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); + case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); + case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); + default : break; + } + return cutlass::Tensor5DCoord(); +} + +/// Returns ImplicitGemm tensor A size in number of elements +CUTLASS_HOST_DEVICE +int64_t implicit_gemm_tensor_a_size( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return 
problem_size.activation_size(); + case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); + case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); + default : break; + } + return 0; +} + +/// Returns ImplicitGemm tensor B size in number of elements +CUTLASS_HOST_DEVICE +int64_t implicit_gemm_tensor_b_size( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); + case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); + case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); + default : break; + } + return 0; +} + +/// Returns ImplicitGemm tensor C size in number of elements +CUTLASS_HOST_DEVICE +int64_t implicit_gemm_tensor_c_size( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.output_size(); + case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); + case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); + default : break; + } + return 0; +} + +} // namespace conv +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/convnd_problem_shape.hpp b/server/punica_kernels/include/cutlass/cutlass/conv/convnd_problem_shape.hpp new file mode 100644 index 00000000..a32389f6 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/convnd_problem_shape.hpp @@ -0,0 +1,574 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file
+  \brief This file contains definitions and utility functions for describing convolution problem shapes.
+*/
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/tensor_coord.h"
+#include "cutlass/conv/convolution.h"
+
+#include "cute/container/array.hpp"
+
+#if ! defined(__CUDACC_RTC__)
+#include <initializer_list>
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass::conv {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Implements the user facing argument for all CUTLASS 3.x convolutions in a rank agnostic fashion.
+// All tensors are flat and by default treated as layout right (NDHWC, KTRSC, NZPQK)
+// Supports asymmetric padding, traversal strides, dilations, and all conv algorithm types.
+template <
+  conv::Operator ConvOp_,
+  int NumSpatialDimensions
+>
+struct ConvProblemShape {
+  //
+  // Alias types for members
+  //
+  static constexpr int RankS = NumSpatialDimensions;
+  static constexpr int RankT = NumSpatialDimensions + 2;
+  static constexpr conv::Operator ConvOp = ConvOp_;
+  using SpatialExtent = cute::array<int, RankS>;
+  using TensorExtent = cute::array<int, RankT>;
+  using TensorStride = cute::array<int64_t, RankT>;
+  using ShapePadding = SpatialExtent;
+  using TraversalStride = SpatialExtent;
+  using ShapeDilation = SpatialExtent;
+  using Corner = SpatialExtent;
+
+  //
+  // Members
+  //
+  cutlass::conv::Mode mode{};
+  TensorExtent shape_A{};
+  TensorStride stride_A{};
+  TensorExtent shape_B{};
+  TensorStride stride_B{};
+  TensorExtent shape_C{};
+  TensorStride stride_C{};
+
+  // asymmetric padding, both upper and lower padding must be >= 0
+  ShapePadding lower_padding{};
+  ShapePadding upper_padding{};
+  TraversalStride traversal_stride{};
+  ShapeDilation dilation{};
+  int groups = 1;
+
+  //
+  // Methods
+  //
+
+  ConvProblemShape() = default;
+
+  // Constructor accepts user facing arguments and computes and stores the corners as its internal state
+  ConvProblemShape(
+    conv::Mode mode,              // convolution/cross-correlation
+    TensorExtent shape_act,       // [n,d,h,w,c]
+    TensorStride stride_act,      // [n,d,h,w,c]
+    TensorExtent shape_flt,       // [k,t,r,s,c]
+    TensorStride stride_flt,      // [k,t,r,s,c]
+    ShapePadding lower_padding,   // [pad_d, pad_h, pad_w]
+    ShapePadding upper_padding,   // [pad_d, pad_h, pad_w]
+    TraversalStride tstride,      // [stride_d, stride_h, stride_w]
+    ShapeDilation dilation,       // [dilation_d, dilation_h, dilation_w]
+    int groups)
+  : mode(mode)
+  , lower_padding(lower_padding)
+  , upper_padding(upper_padding)
+  , traversal_stride(tstride)
+  , dilation(dilation)
+  , groups(groups) {
+
+    auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt);
+    set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
+  }
+
+  // Allow user input of xformed activation stride to support non-packed strides.
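+  // For example, an [n,z,p,q,k] output written into a padded buffer can pass its actual
+  // strides here: the asserts below only require the innermost (k) stride to be 1 and each
+  // remaining stride to be at least the packed stride derived by packed_stride_right_major().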
+ ConvProblemShape( + conv::Mode mode, // convolution/cross-correlation + TensorExtent shape_act, // [n,d,h,w,c] + TensorStride stride_act, // [n,d,h,w,c] + TensorExtent shape_flt, // [k,t,r,s,c] + TensorStride stride_flt, // [k,t,r,s,c] + TensorStride stride_xformed_act, // [n,z,p,q,k] + ShapePadding lower_padding, // [pad_d, pad_h, pad_w] + ShapePadding upper_padding, // [pad_d, pad_h, pad_w] + TraversalStride tstride, // [stride_d, stride_h, stride_w] + ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w] + int groups) + : mode(mode) + , lower_padding(lower_padding) + , upper_padding(upper_padding) + , traversal_stride(tstride) + , dilation(dilation) + , groups(groups) { + + CUTLASS_ASSERT(stride_act[RankT - 1] == 1); + CUTLASS_ASSERT(stride_flt[RankT - 1] == 1); + CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1); + + auto stride_act_packed = packed_stride_right_major(shape_act); + auto stride_flt_packed = packed_stride_right_major(shape_flt); + auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < RankT - 1; ++i) { + CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]); + CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]); + CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]); + } + + set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } + + // Constructor accepts user facing arguments and presume packed tensor strides in canonical (CWHDN) order. + ConvProblemShape( + conv::Mode mode, + TensorExtent shape_act, + TensorExtent shape_flt, + ShapePadding lower_padding, + ShapePadding upper_padding, + TraversalStride tstride, + ShapeDilation dilation, + int groups) + : ConvProblemShape( + mode, + shape_act, + packed_stride_right_major(shape_act), + shape_flt, + packed_stride_right_major(shape_flt), + lower_padding, + upper_padding, + tstride, + dilation, + groups) { + } + +#if ! 
defined(__CUDACC_RTC__) + // Constructor accepts user facing arguments and computes to stores the corners as its internal state + ConvProblemShape( + conv::Mode mode, + std::initializer_list shape_act_, + std::initializer_list stride_act_, + std::initializer_list shape_flt_, + std::initializer_list stride_flt_, + std::initializer_list lower_padding_, + std::initializer_list upper_padding_, + std::initializer_list traversal_stride_, + std::initializer_list dilation_, + int groups) + : mode(mode) + , groups(groups) { + + TensorExtent shape_act{}; + TensorStride stride_act{}; + TensorExtent shape_flt{}; + TensorStride stride_flt{}; + + assert(shape_act_.size() == shape_act.size()); + assert(stride_act_.size() == stride_act.size()); + assert(shape_flt_.size() == shape_flt.size()); + assert(stride_flt_.size() == stride_flt.size()); + assert(lower_padding_.size() == lower_padding.size()); + assert(upper_padding_.size() == upper_padding.size()); + assert(traversal_stride_.size() == traversal_stride.size()); + assert(dilation_.size() == dilation.size()); + + std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); + std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin()); + std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); + std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin()); + std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); + std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); + std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); + std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); + + auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt); + set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } + + // Allow user input of xformed activation stride to support non-packed strides. 
+ ConvProblemShape( + conv::Mode mode, + std::initializer_list shape_act_, + std::initializer_list stride_act_, + std::initializer_list shape_flt_, + std::initializer_list stride_flt_, + std::initializer_list stride_xformed_act_, + std::initializer_list lower_padding_, + std::initializer_list upper_padding_, + std::initializer_list traversal_stride_, + std::initializer_list dilation_, + int groups) + : mode(mode) + , groups(groups) { + TensorExtent shape_act{}; + TensorStride stride_act{}; + TensorExtent shape_flt{}; + TensorStride stride_flt{}; + TensorStride stride_xformed_act{}; + + std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); + std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin()); + std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); + std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin()); + std::copy(stride_xformed_act_.begin(), stride_xformed_act_.end(), stride_xformed_act.begin()); + std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); + std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); + std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); + std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); + + CUTLASS_ASSERT(stride_act[RankT - 1] == 1); + CUTLASS_ASSERT(stride_flt[RankT - 1] == 1); + CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1); + + auto stride_act_packed = packed_stride_right_major(shape_act); + auto stride_flt_packed = packed_stride_right_major(shape_flt); + auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < RankT - 1; ++i) { + CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]); + CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]); + CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]); + } + + set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } + + // Constructor accepts user facing arguments and computes to stores the corners as its internal state + ConvProblemShape( + conv::Mode mode, + std::initializer_list shape_act_, + std::initializer_list shape_flt_, + std::initializer_list lower_padding_, + std::initializer_list upper_padding_, + std::initializer_list traversal_stride_, + std::initializer_list dilation_, + int groups) + : mode(mode) + , groups(groups) { + TensorExtent shape_act{}; + TensorStride stride_act{}; + TensorExtent shape_flt{}; + TensorStride stride_flt{}; + + assert(shape_act_.size() == shape_act.size()); + assert(shape_flt_.size() == shape_flt.size()); + assert(lower_padding_.size() == lower_padding.size()); + assert(upper_padding_.size() == upper_padding.size()); + assert(traversal_stride_.size() == traversal_stride.size()); + assert(dilation_.size() == dilation.size()); + + std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); + std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); + std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); + std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); + std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); + std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); + stride_act = packed_stride_right_major(shape_act); + stride_flt = packed_stride_right_major(shape_flt); + + auto [shape_xformed_act, stride_xformed_act] = 
calculate_xformed_act(shape_act, shape_flt); + set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } +#endif // not defined(__CUDACC_RTC__) + + // Set shape and stride of tensor A/B/C according to following table: + // | | Fprop | Dgrad | Wgrad | + // | ------ | ------ | ------ | ------| + // | ShapeA | NDHWC | NZPQK | NZPQK | + // | ShapeB | KTRSC | KTRSC | NDHWC | + // | ShapeC | NZPQK | NDHWC | KTRSC | + // + CUTLASS_HOST_DEVICE + constexpr void + set_shape_stride_ABC( + TensorExtent shape_act, + TensorStride stride_act, + TensorExtent shape_flt, + TensorStride stride_flt, + TensorExtent shape_xformed_act, + TensorStride stride_xformed_act) { + + if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { + shape_A = shape_act; + stride_A = stride_act; + shape_B = shape_flt; + stride_B = stride_flt; + shape_C = shape_xformed_act; + stride_C = stride_xformed_act; + } + else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { + shape_A = shape_xformed_act; + stride_A = stride_xformed_act; + shape_B = shape_flt; + stride_B = stride_flt; + shape_C = shape_act; + stride_C = stride_act; + } + else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + shape_A = shape_xformed_act; + stride_A = stride_xformed_act; + shape_B = shape_act; + stride_B = stride_act; + shape_C = shape_flt; + stride_C = stride_flt; + } + } + + // Get problem shape MNK according to following table: + // | | Fprop | Dgrad | Wgrad | + // | ---- | --------- | -------- | -------- | + // | Shape_M | (Q,P,Z,N) | (W,H,D,N) | (K) | + // | Shape_N | (K) | (C) | (C,S,R,T) | + // | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) | + CUTLASS_HOST_DEVICE + constexpr auto + get_transformed_problem_shape_MNK() const { + using cute::insert; + using cute::make_shape; + using cute::reverse; + using cute::take; + + if constexpr (ConvOp == conv::Operator::kWgrad) { + auto M_xformed = shape_C[0]; + auto N_xformed = reverse(take<1, RankT>(shape_C)); + auto K_xformed = reverse(take<0, RankT - 1>(shape_A)); + + return make_shape(M_xformed, N_xformed, K_xformed); + } + else if constexpr (ConvOp == conv::Operator::kFprop){ + auto M_xformed = reverse(take<0, RankT - 1>(shape_C)); + auto N_xformed = shape_C[RankT - 1]; + auto K_xformed = reverse(take<1, RankT>(shape_B)); + + return make_shape(M_xformed, N_xformed, K_xformed); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + auto M_xformed = reverse(take<0,RankT - 1>(shape_C)); + auto N_xformed = shape_C[RankT - 1]; + // shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T] + auto K_xformed = insert<0>( + (reverse(take<1,RankT - 1>(shape_B))), + shape_B[0]); + return make_shape(M_xformed, N_xformed, K_xformed); + } + } + + + // Get A extents. + // fprop: A extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C)) + // wgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((K), (Q,P,Z,N)) + // dgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K)) + CUTLASS_HOST_DEVICE + constexpr auto + get_shape_A() const { + using cute::make_shape; + using cute::take; + + if constexpr (ConvOp == conv::Operator::kFprop || + ConvOp == conv::Operator::kDgrad) { + return make_shape( + cute::reverse(take<0, RankT - 1>(shape_A)), + shape_A[RankT - 1]); + } + // For wgrad kernel, we need to linearize NZPQ for tensor A + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return make_shape( + shape_A[RankT - 1], + cute::product(take<0, RankT - 1>(shape_A))); + } + } + + // Get B extents. 
+ // fprop: B extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T)) + // wgrad: B extents array contains [N,D,H,W,C]. Turn that into ((C), (W,H,D,N)) + // dgrad: B extents array contains [K,T,R,S,C]. Turn that into ((C), (K,S,R,T)) + CUTLASS_HOST_DEVICE + constexpr auto + get_shape_B() const { + using cute::make_shape; + using cute::reverse; + using cute::take; + + if constexpr (ConvOp == conv::Operator::kFprop) { + return make_shape( + shape_B[0], + reverse(take<1, RankT>(shape_B))); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return make_shape( + shape_B[RankT - 1], + reverse(take<0, RankT - 1>(shape_B))); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + // shape_B: [K,T,R,S,C], return: [(C),(K,S,R,T)] + return make_shape( + shape_B[RankT - 1], + cute::insert<0>( + reverse(take<1, RankT - 1>(shape_B)), + shape_B[0])); + } + } + + // Static method that returns the canonical strides of tensors (layouts are right major and compact) + CUTLASS_HOST_DEVICE + static constexpr TensorStride + packed_stride_right_major(TensorExtent const& extents) { + TensorStride strides{}; + strides[RankT-1] = 1; + cute::for_each(cute::make_rseq{}, [&](auto i) { + strides[i] = extents[i+1] * strides[i+1]; + }); + return strides; + } + + // Static method that returns the packed logical size of any TensorExtent + CUTLASS_HOST_DEVICE + static constexpr size_t + size(TensorExtent const& extents) { + size_t size = 1; + cute::for_each(cute::make_seq{}, [&](auto i) { + size *= extents[i]; + }); + return size; + } + + CUTLASS_HOST_DEVICE + constexpr size_t + size_A() const { + return shape_A[0] * stride_A[0]; + } + + CUTLASS_HOST_DEVICE + constexpr size_t + size_B() const { + return shape_B[0] * stride_B[0]; + } + + CUTLASS_HOST_DEVICE + constexpr size_t + size_C() const { + return shape_C[0] * stride_C[0]; + } + + // Equality operator + CUTLASS_HOST_DEVICE + bool operator==(ConvProblemShape const& rhs) const { + using cute::for_each; + using cute::make_seq; + + bool is_equal = true; + + // Compare all tensor extents + for_each(make_seq{}, [&](auto i) { + is_equal = is_equal + && (shape_A[i] == rhs.shape_A[i]) + && (shape_B[i] == rhs.shape_B[i]); + }); + + // Compare all spatial extents + for_each(make_seq{}, [&](auto i) { + is_equal = is_equal + && (lower_padding[i] == rhs.lower_padding[i]) + && (upper_padding[i] == rhs.upper_padding[i]) + && (traversal_stride[i] == rhs.traversal_stride[i]) + && (dilation[i] == rhs.dilation[i]); + }); + + return is_equal; + } + + /// Inequality operator + CUTLASS_HOST_DEVICE + bool operator!=(ConvProblemShape const &rhs) const { + return !(*this == rhs); + } + +private: + CUTLASS_HOST_DEVICE + constexpr auto + calculate_xformed_act(TensorExtent shape_act, TensorExtent shape_flt) { + TensorExtent shape_xformed_act{}; + // calculate n,z,p,q,k. 
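+    // Worked example: for act_ext = 8, filter_ext = 3, pad_total = 2, dilation = 1 and
+    // tstride = 2, the nzpqk_extent lambda below gives 1 + (8 + 2 - ((3 - 1) * 1 + 1)) / 2 = 4,
+    // i.e. the familiar floor((extent + padding - effective_filter) / stride) + 1 output size.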
+ // a helper lambda to compute a single spatial extent of the nzpqk tensor + auto nzpqk_extent = [](int act_ext, int filter_ext, int pad_total, int dilation, int tstride) { + return 1 + (act_ext + pad_total - ((filter_ext -1) * dilation + 1)) / tstride; + }; + + shape_xformed_act[0] = shape_act[0]; // Activation N extent + cute::for_each(cute::make_seq{}, [&](auto i) { + shape_xformed_act[i+1] = nzpqk_extent( + shape_act[i+1], shape_flt[i+1], upper_padding[i] + lower_padding[i], dilation[i], traversal_stride[i]); + }); + shape_xformed_act[RankT-1] = shape_flt[0]; // Filter K extent + + TensorStride stride_xformed_act = packed_stride_right_major(shape_xformed_act); + + return cute::make_tuple(shape_xformed_act, stride_xformed_act); + } +}; + +template< + conv::Operator ConvOp, + int SpatialDim +> +void print(ConvProblemShape const& problem) { + printf("ConvProblemShape with %d spatial dimensions implementing cutlass::conv::Operator::%d\n", + SpatialDim, int(ConvOp)); + printf("\tTensorA: "); + cute::print(problem.shape_A); printf(":"); + cute::print(problem.stride_A); printf("\n"); + printf("\tTensorB: "); + cute::print(problem.shape_B); printf(":"); + cute::print(problem.stride_B); printf("\n"); + printf("\tTensorC: "); + cute::print(problem.shape_C); printf(":"); + cute::print(problem.stride_C); printf("\n"); + printf("\tLower padding: "); print(problem.lower_padding); printf("\n"); + printf("\tUpper padding: "); print(problem.upper_padding); printf("\n"); + printf("\tTraversal strides: "); print(problem.traversal_stride); printf("\n"); + printf("\tDilation: "); print(problem.dilation); printf("\n"); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/convolution.h b/server/punica_kernels/include/cutlass/cutlass/conv/convolution.h new file mode 100644 index 00000000..a61f573e --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/convolution.h @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief + +This file contains definitions and utility functions for describing convolution problem sizes in terms of +activation (NHWC), filter (KRSC), output (NPQK), padding (pad_h, pad_w), stride (stride_h, stride_w), and +dilation (dilation_h, dilation_w). Furthermore, it defines helper functions to map CUTLASS's implicit gemm +tensor extents, sizes, and data types to that of the convolution's extents, sizes, and data types. + + * Mapping convolutions to Gemm computation * + +Cutlass implements convolutions with the Implicit Gemm algorithm. This algorithm performs a gemm +(general matrix-matrix multiply) on the convolution tensors Activation, Filter, and Output. +The underlying gemm operation follows the standard gemm definition: + + C = A * B + C + + A and B are input matrices + C is source and output matrix + + +For the three convolutional operators (Fprop, Dgrad, Wgrad), ImplicitGemm matrices A, B, and C are mapped +to convolution tensors Activation, Filter and Output as described in the table below. + + ___________________________________________________________________________ + ConvolutionalOperator | A | B | C + ___________________________________________________________________________ + | | | | | + | Fprop | Activation | Filter | Output | + | Dgrad | Output | Filter | Activation | + | Wgrad | Output | Activation | Filter | + ___________________________________________________________________________ + +In convolution codebase, DO NOT mix using (A, B, C) with (Activation, Filter, Output). + +For example, it's confusing and error prone to document a convolution class or function +as operating on "A, B, Output." Instead, use the mapping functions below, +and adhere to using either A, B, C or Activation, Filter, Output. + +Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap +Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap +*/ + +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm_enumerated_types.h" +#include "cutlass/matrix_coord.h" + +namespace cutlass { +namespace conv { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Convolutional operator +enum class Operator { + kFprop, + kDgrad, + kWgrad +}; + +/// Distinguishes convolution from cross correlation +enum class Mode { + kCrossCorrelation, + kConvolution +}; + +/// Selects among several implementation variants trading off performance with simplicity +enum class IteratorAlgorithm { + kAnalytic, ///< functionally correct in all cases but lower performance + kOptimized, ///< optimized for R <= 32, S <= 32 and unity-stride dgrad + kFixedChannels, ///< Analytic algorithm optimized for fixed channel count (C == AccessSize) + kFewChannels, ///< Analytic algorithm optimized for few channels (C divisible by AccessSize) + kFixedStrideDilation ///< Optimized for fixed stride and dilation +}; + +/// Distinguishes among partial specializations that accelerate certain problems where convolution +/// stride is unit. +enum class StrideSupport { + kStrided, ///< arbitrary convolution stride + kUnity, ///< unit convolution stride + kFixed ///< fixed convolution stride +}; + +/// Identifies split-K mode +enum class SplitKMode { + kNone, + kSerial, + kParallel +}; + +/// Identifies group mode +enum class GroupMode { + kNone, + kSingleGroup, ///< One CTA calculates one group or less + kMultipleGroup, ///< One CTA calculates multiple groups + kDepthwise ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups) +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Shape of a tensor +template < + int N = 1, + int H = 1, + int W = 1, + int C = 1 +> +struct TensorNHWCShape { + static int const kN = N; + static int const kH = H; + static int const kW = W; + static int const kC = C; + + static int const kHW = H * W; + static int const kNHW = N * kHW; + static int const kNHWC = N * H * W * C; + + static int const kCount = kNHWC; + + // + // Static member functions + // + + /// Returns a Coord object + CUTLASS_HOST_DEVICE + static Coord<4> toCoord() { + return make_Coord(kN, kH, kW, kC); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Shape of a conv2d stride, which controls how the filter convolves around the input volume +template < + /// Stride in horizontal direction + int u = 1, + /// Stride in vertical direction + int v = 1 +> +struct Stride2D { + static int const kU = u; + static int const kV = v; + + // + // Static member functions + // + + /// Returns a Coord object + CUTLASS_HOST_DEVICE + static Coord<2> toCoord() { + return make_Coord(kU, kV); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace conv +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/device/conv_universal_adapter.hpp b/server/punica_kernels/include/cutlass/cutlass/conv/device/conv_universal_adapter.hpp new file mode 100644 index 00000000..69cfbaba --- /dev/null +++ 
b/server/punica_kernels/include/cutlass/cutlass/conv/device/conv_universal_adapter.hpp @@ -0,0 +1,414 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +// common +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/arch/mma.h" +#include "cutlass/trace.h" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/device_kernel.h" + +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::device { + +//////////////////////////////////////////////////////////////////////////////// + +/*! + ConvUniversalAdapter is a stateful, reusable handle built around a kernel + of type cutlass::conv::kernel::ConvUniversal. + + It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs + to create it from the host facing arguments. For power users, static methods + are exposed that bypass the stateful methods or args->params lowering. 
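+
+  A minimal host-side call sequence (sketch only; `MyConvKernel`, `args`, `workspace`, and
+  `stream` are placeholders supplied by the caller, not defined in this header):
+
+    ConvUniversalAdapter<MyConvKernel> conv_op;
+    if (ConvUniversalAdapter<MyConvKernel>::can_implement(args) == Status::kSuccess) {
+      size_t workspace_bytes = conv_op.get_workspace_size(args);   // allocate this much device memory
+      conv_op.initialize(args, workspace, stream);                 // args -> Params lowering
+      conv_op.run(stream);                                         // launch with the stored Params
+    }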
+*/ +template +class ConvUniversalAdapter +{ +public: + using ConvKernel = ConvKernel_; + using TileShape = typename ConvKernel::TileShape; + using ElementA = typename ConvKernel::ElementA; + using ElementB = typename ConvKernel::ElementB; + using ElementC = typename ConvKernel::ElementC; + using ElementD = typename ConvKernel::ElementD; + using ElementAccumulator = typename ConvKernel::TiledMma::ValTypeC; + using DispatchPolicy = typename ConvKernel::DispatchPolicy; + using CollectiveMainloop = typename ConvKernel::CollectiveMainloop; + using CollectiveEpilogue = typename ConvKernel::CollectiveEpilogue; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + // Tease out meta-information about the conv algorithm + static constexpr conv::Operator kConvolutionalOperator = DispatchPolicy::ConvOp; + static constexpr int NumSpatialDimensions = ConvKernel::NumSpatialDimensions; + + // If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop! + using OperatorClass = cute::conditional_t< + (cute::size(typename ConvKernel::TiledMma::AtomThrID{}) > 1), + cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>; + + using ArchTag = typename ConvKernel::ArchTag; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape< + cute::size<0>(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = ConvKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma + // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K + // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename ConvKernel::TiledMma{})) / 32); + static constexpr int WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape< + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA>(); + static int constexpr kAlignmentB = detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, 
ElementB>(); + static int constexpr kAlignmentC = detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); + static int constexpr kAlignmentD = detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + /// Argument structure: User API + using Arguments = typename ConvKernel::Arguments; + /// Argument structure: Kernel API + using Params = typename ConvKernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Determines whether the conv can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (ConvKernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += ConvKernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = ConvKernel::to_underlying_arguments(args, workspace); + return ConvKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return ConvKernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("ConvUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = ConvKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + ConvKernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes conv state from arguments. + Status + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + CUTLASS_TRACE_HOST("ConvUniversal::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + size_t workspace_bytes = ConvKernel::get_workspace_size(args); + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) { + if (!workspace) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + return Status::kErrorWorkspaceNull; + } + + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + // Initialize the Params structure + params_ = ConvKernel::to_underlying_arguments(args, workspace); + + // Don't set the function attributes - require the CudaHostAdapter to set it. + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // account for dynamic smem capacity if needed + int smem_size = ConvKernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + } + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("ConvUniversal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = ConvKernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+  /// Supplied params struct must be constructed by calling ConvKernel::to_underlying_arguments()
+  static Status
+  run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) {
+    CUTLASS_TRACE_HOST("ConvUniversal::run()");
+    dim3 const block = ConvKernel::get_block_shape();
+    dim3 const grid = get_grid_shape(params);
+
+    // configure smem size and carveout
+    int smem_size = ConvKernel::SharedStorageSize;
+
+    Status launch_result;
+    // Use extended launch API only for mainloops that use it
+    if constexpr(ConvKernel::ArchTag::kMinComputeCapability >= 90) {
+      dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
+                   cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
+                   cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{}));
+      void* kernel_params[] = {&params};
+      if constexpr (kEnableCudaHostAdapter) {
+        //
+        // Use the cuda host adapter
+        //
+        CUTLASS_ASSERT(cuda_adapter);
+        if (cuda_adapter) {
+
+          launch_result = cuda_adapter->launch(
+            grid, cluster, block, smem_size, stream, kernel_params, 0
+          );
+        }
+        else {
+          return Status::kErrorInternal;
+        }
+      }
+      else {
+
+        CUTLASS_ASSERT(cuda_adapter == nullptr);
+        void const* kernel = (void const*) device_kernel<ConvKernel>;
+
+        launch_result = ClusterLauncher::launch(
+          grid, cluster, block, smem_size, stream, kernel, kernel_params);
+
+      }
+    }
+    else {
+      launch_result = Status::kSuccess;
+
+      if constexpr (kEnableCudaHostAdapter) {
+        CUTLASS_ASSERT(cuda_adapter);
+        if (cuda_adapter) {
+          void* kernel_params[] = {&params};
+
+          launch_result = cuda_adapter->launch(
+            grid, block, smem_size, stream, kernel_params, 0
+          );
+
+        }
+        else {
+          return Status::kErrorInternal;
+        }
+      }
+      else {
+        CUTLASS_ASSERT(cuda_adapter == nullptr);
+        device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params);
+      }
+    }
+
+    cudaError_t result = cudaGetLastError();
+    if (cudaSuccess == result && Status::kSuccess == launch_result) {
+      return Status::kSuccess;
+    }
+    else {
+      CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
+      return Status::kErrorInternal;
+    }
+  }
+
+  //
+  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
+  //
+
+  /// Launches the kernel after first constructing Params internal state from supplied arguments.
+  Status
+  run(
+    Arguments const& args,
+    void* workspace = nullptr,
+    cudaStream_t stream = nullptr,
+    CudaHostAdapter *cuda_adapter = nullptr
+  ) {
+    Status status = initialize(args, workspace, stream, cuda_adapter);
+    if (Status::kSuccess == status) {
+      status = run(params_, stream, cuda_adapter);
+    }
+    return status;
+  }
+
+  /// Launches the kernel after first constructing Params internal state from supplied arguments.
+  Status
+  operator()(
+    Arguments const& args,
+    void* workspace = nullptr,
+    cudaStream_t stream = nullptr,
+    CudaHostAdapter *cuda_adapter = nullptr) {
+    return run(args, workspace, stream, cuda_adapter);
+  }
+
+  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
+  Status
+  run(cudaStream_t stream = nullptr) {
+    return run(params_, stream);
+  }
+
+  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
+ Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/device/direct_convolution.h b/server/punica_kernels/include/cutlass/cutlass/conv/device/direct_convolution.h new file mode 100644 index 00000000..5c259d46 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/device/direct_convolution.h @@ -0,0 +1,268 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Template for device-level Depthwise Convolution +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/conv/convolution.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DirectConvolution { +public: + + using UnderlyingKernel = DirectConvolutionKernel_; + + using ElementA = typename UnderlyingKernel::ElementA; + using LayoutA = typename UnderlyingKernel::LayoutA; + using ElementB = typename UnderlyingKernel::ElementB; + using LayoutB = typename UnderlyingKernel::LayoutB; + using ElementC = typename UnderlyingKernel::ElementC; + using LayoutC = typename UnderlyingKernel::LayoutC; + using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; + using ElementCompute = typename UnderlyingKernel::ElementCompute; + using OperatorClass = typename UnderlyingKernel::OperatorClass; + using ArchTag = typename UnderlyingKernel::ArchTag; + using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; + using WarpShape = typename UnderlyingKernel::WarpShape; + using InstructionShape = typename UnderlyingKernel::InstructionShape; + using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; + using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; + static int const kStages = UnderlyingKernel::kStages; + static int const kConvDim = UnderlyingKernel::kConvDim; + using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; + using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; + using MathOperator = typename UnderlyingKernel::MathOperator; + + static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; + static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; + static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; + + static int const kWarpCount = + (ThreadblockShape::kM / WarpShape::kM) * + (ThreadblockShape::kN / WarpShape::kN) * + (ThreadblockShape::kK / WarpShape::kK); + + /// Argument structure + using Arguments = typename UnderlyingKernel::Arguments; + + using ReorderKernel = typename UnderlyingKernel::ReorderKernel; + + private: + + /// Kernel parameters object + typename UnderlyingKernel::Params params_; + +public: + + /// Constructs Implicit GEMM + DirectConvolution() { } + + /// Determines whether the Implicit GEMM can execute the given problem. 
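+  /// Checks the A/B iterators, requires depthwise group mode (with C or K equal to groups),
+  /// enforces output-channel alignment for the epilogue, and rejects problems whose swizzled
+  /// grid exceeds the CUDA grid y/z limits.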
+ static Status can_implement(Arguments const &args) { + + // dispatch to iterators + Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + if (kGroupMode != conv::GroupMode::kDepthwise) { + return Status::kErrorInvalidProblem; + } + + // C and K should be multiple of groups + if (args.problem_size.K != args.problem_size.groups && + args.problem_size.C != args.problem_size.groups) { + return Status::kErrorInvalidProblem; + } + + + static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; + if (kConvolutionalOperator == conv::Operator::kFprop) { + if (args.problem_size.K % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kDgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kWgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape( + threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices)); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) { + + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + return 0; + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + // initialize the params structure from the arguments + params_ = typename UnderlyingKernel::Params( + args, + static_cast(workspace) + ); + + int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Initializes GEMM state from arguments. + Status update(Arguments const &args, void *workspace = nullptr) { + + // update the params structure from the arguments + params_.ptr_A = args.ref_A.data(); + params_.ptr_B = args.ref_B.data(); + params_.ptr_C = args.ref_C.data(); + params_.ptr_D = args.ref_D.data(); + params_.output_op = args.output_op; + params_.ptr_reordered_B = args.ref_reordered_B.data();; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + // Launch reorder kernel + if (params_.ptr_reordered_B != nullptr) { + dim3 grid = ReorderKernel::get_grid_shape(params_); + dim3 block = ReorderKernel::get_block_shape(); + + cutlass::Kernel<<>>(params_); + } + + // Launch main kernel + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(32 * kWarpCount, 1, 1); + + // Dynamic SMEM size based on input params. 
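+    // Launching with more than 48 KB of dynamic shared memory requires an explicit opt-in,
+    // so the requested size is applied to the kernel via cudaFuncSetAttribute below before launch.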
+ int smem_size = int(params_.get_smem_size()); + + // Make sure we can use that much shared memory. + cudaError_t status = + cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (status != cudaSuccess) + return Status::kErrorInternal; + + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } + + int get_smem_size() { return int(params_.get_smem_size()); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/device/implicit_gemm_convolution.h b/server/punica_kernels/include/cutlass/cutlass/conv/device/implicit_gemm_convolution.h new file mode 100644 index 00000000..dfb146f2 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/device/implicit_gemm_convolution.h @@ -0,0 +1,362 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Template for device-level Implicit GEMM Convolution +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ImplicitGemmConvolution { +public: + + using UnderlyingKernel = ImplicitGemmKernel_; + + using ElementA = typename UnderlyingKernel::ElementA; + using LayoutA = typename UnderlyingKernel::LayoutA; + using ElementB = typename UnderlyingKernel::ElementB; + using LayoutB = typename UnderlyingKernel::LayoutB; + using ElementC = typename UnderlyingKernel::ElementC; + using LayoutC = typename UnderlyingKernel::LayoutC; + using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; + using ElementCompute = typename UnderlyingKernel::ElementCompute; + using OperatorClass = typename UnderlyingKernel::OperatorClass; + using ArchTag = typename UnderlyingKernel::ArchTag; + using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; + using WarpShape = typename UnderlyingKernel::WarpShape; + using InstructionShape = typename UnderlyingKernel::InstructionShape; + using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; + using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; + static int const kStages = UnderlyingKernel::kStages; + static int const kConvDim = UnderlyingKernel::kConvDim; + using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; + using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; + using MathOperator = typename UnderlyingKernel::MathOperator; + + static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; + static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; + static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + static int const kWarpCount = + (ThreadblockShape::kM / WarpShape::kM) * + (ThreadblockShape::kN / WarpShape::kN) * + (ThreadblockShape::kK / WarpShape::kK); + + /// Argument structure + using Arguments = typename UnderlyingKernel::Arguments; + +private: + + /// Kernel parameters object + typename UnderlyingKernel::Params params_; + +public: + + /// Constructs Implicit GEMM + ImplicitGemmConvolution() { } + + /// Determines whether the Implicit GEMM can execute the given problem. 
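As an aside before `can_implement` below: the `kWarpCount` constant defined above is purely a function of the threadblock and warp tile shapes, and it fixes the block dimension used later in `run()` (`32 * kWarpCount` threads). A minimal compile-time check of that arithmetic, assuming the common 128x128x32 threadblock / 64x64x32 warp configuration (illustrative values, not mandated by this header):

```cpp
#include "cutlass/gemm/gemm.h"

// Illustrative shapes only: any threadblock/warp pair with evenly dividing
// extents follows the same rule.
using TbShape   = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;

static constexpr int kWarps =
    (TbShape::kM / WarpShape::kM) *   // 2 warps along M
    (TbShape::kN / WarpShape::kN) *   // 2 warps along N
    (TbShape::kK / WarpShape::kK);    // 1 warp along K

static_assert(kWarps == 4, "4 warps => run() launches 32 * 4 = 128-thread blocks");
```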
+ static Status can_implement(Arguments const &args) { + + // dispatch to iterators + Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + // check group conv constraint + if (args.problem_size.groups != 1) { + if (kGroupMode == conv::GroupMode::kNone) { + return Status::kErrorInvalidProblem; + } + + // C and K should be multiple of groups + if (args.problem_size.K % args.problem_size.groups || + args.problem_size.C % args.problem_size.groups) { + return Status::kErrorInvalidProblem; + } + + // split-k is not supported + if (args.problem_size.split_k_slices != 1) { + return Status::kErrorInvalidProblem; + } + + int k_per_group = args.problem_size.K / args.problem_size.groups; + // k_per_group should be multiple of ThreadblockShape N, one CTA calculate one group + if (kGroupMode == conv::GroupMode::kSingleGroup && k_per_group % ThreadblockShape::kN) { + return Status::kErrorInvalidProblem; + } + // ThreadblockShape::kN should be divisible by k_per_group, one CTA calculate multiple groups + if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) { + return Status::kErrorInvalidProblem; + } + + // current optimized iterator algo only supports SingleGroup mode + if (kIteratorAlgorithm == IteratorAlgorithm::kOptimized && + kGroupMode != conv::GroupMode::kSingleGroup) { + return Status::kErrorInvalidProblem; + } + } + + static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; + if (kConvolutionalOperator == conv::Operator::kFprop) { + if (args.problem_size.K % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kDgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kWgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } + + // check for unsupported problem sizes for strided dgrad implementation + if (kConvolutionalOperator == conv::Operator::kDgrad && + kStrideSupport == conv::StrideSupport::kStrided) { + + // split-k (serial or parallel) is not supported for strided dgrad + if(args.problem_size.split_k_slices > 1) { + return Status::kErrorNotSupported; + } + + // dilation > {1x1} is not supported for strided dgrad + if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) { + return Status::kErrorNotSupported; + } + } + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape( + threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices)); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) { + + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t workspace_bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + 
args.problem_size.split_k_slices); + + if(args.split_k_mode == SplitKMode::kParallel) { + + // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. + // The user needs to call a reduction operator to optain the final output tensor + workspace_bytes = + sizeof(ElementAccumulator) * + size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) * + size_t(grid_tiled_shape.k()); + } + + else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) { + + // Split-K serial: The user workspace is used to store semaphore and serialize writing the + // final reduced output to user's output tensor + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + return workspace_bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + if (args.problem_size.split_k_slices > 1) { + + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream); + + if (status != cudaSuccess) { + return Status::kErrorInternal; + } + } + + // initialize the params structure from the arguments + params_ = typename UnderlyingKernel::Params( + args, + static_cast(workspace) + ); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + + return Status::kSuccess; + } + + /// Initializes GEMM state from arguments. + Status update(Arguments const &args, void *workspace = nullptr) { + + // update the params structure from the arguments + params_.ptr_A = args.ref_A.data(); + params_.ptr_B = args.ref_B.data(); + params_.ptr_C = args.ref_C.data(); + params_.ptr_D = args.ref_D.data(); + params_.output_op = args.output_op; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(32 * kWarpCount, 1, 1); + + int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); + cutlass::Status launch_result = cutlass::Status::kSuccess ; + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + + void* kernel_params[] = {¶ms_}; + launch_result = cuda_adapter->launch( + grid, dim3(1,1,1), block, smem_size, stream, kernel_params, 0 + ); + } + else { + launch_result = Status::kErrorInternal; + } + } + else { + cutlass::Kernel<<>>(params_); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + /// Runs the kernel using initialized state. 
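Before the convenience `operator()` overloads that follow, here is a hedged sketch of how a host application typically strings these members together. `Conv` stands for any concrete `ImplicitGemmConvolution` instantiation and error handling is abbreviated; this is not a definitive driver, just the usual query, workspace, initialize, run sequence:

```cpp
#include <cstddef>
#include <cuda_runtime.h>
#include "cutlass/cutlass.h"

// Hedged sketch of the standard query -> workspace -> initialize -> run flow.
template <typename Conv>
cutlass::Status run_implicit_gemm(typename Conv::Arguments const &args,
                                  cudaStream_t stream) {
  // Reject unsupported problem sizes/alignments up front.
  if (Conv::can_implement(args) != cutlass::Status::kSuccess) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Split-K needs device workspace (partial outputs or per-tile semaphores).
  void *workspace = nullptr;
  size_t workspace_bytes = Conv::get_workspace_size(args);
  if (workspace_bytes != 0 &&
      cudaMalloc(&workspace, workspace_bytes) != cudaSuccess) {
    return cutlass::Status::kErrorInternal;
  }

  Conv op;
  cutlass::Status status = op.initialize(args, workspace, stream);
  if (status == cutlass::Status::kSuccess) {
    status = op.run(stream);
  }

  cudaFree(workspace);
  return status;
}
```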
+ Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + return run(stream, cuda_adapter); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (status == Status::kSuccess) { + status = run(stream, cuda_adapter); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/device/implicit_gemm_convolution_fusion.h b/server/punica_kernels/include/cutlass/cutlass/conv/device/implicit_gemm_convolution_fusion.h new file mode 100644 index 00000000..7b15520d --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/device/implicit_gemm_convolution_fusion.h @@ -0,0 +1,268 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Template for device-level fused activation's scale+bias+relu and Implicit GEMM Convolution +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/conv/convolution.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ImplicitGemmConvolutionFusion { +public: + + using ImplicitGemmFusionKernel = ImplicitGemmFusionKernel_; + + using ElementA = typename ImplicitGemmFusionKernel::ElementA; + using LayoutA = typename ImplicitGemmFusionKernel::LayoutA; + using ElementB = typename ImplicitGemmFusionKernel::ElementB; + using LayoutB = typename ImplicitGemmFusionKernel::LayoutB; + +// using ElementScaleBias = typename ImplicitGemmFusionKernel::ElementScaleBias; +// using LayoutScaleBias = typename ImplicitGemmFusionKernel::LayoutScaleBias; + + using ElementC = typename ImplicitGemmFusionKernel::ElementC; + using LayoutC = typename ImplicitGemmFusionKernel::LayoutC; + using ElementAccumulator = typename ImplicitGemmFusionKernel::ElementAccumulator; + using ElementCompute = typename ImplicitGemmFusionKernel::ElementCompute; + using OperatorClass = typename ImplicitGemmFusionKernel::OperatorClass; + using ArchTag = typename ImplicitGemmFusionKernel::ArchTag; + using ThreadblockShape = typename ImplicitGemmFusionKernel::ThreadblockShape; + using WarpShape = typename ImplicitGemmFusionKernel::WarpShape; + using InstructionShape = typename ImplicitGemmFusionKernel::InstructionShape; + using ThreadblockSwizzle = typename ImplicitGemmFusionKernel::ThreadblockSwizzle; + using EpilogueOutputOp = typename ImplicitGemmFusionKernel::EpilogueOutputOp; + static int const kStages = ImplicitGemmFusionKernel::kStages; + static int const kConvDim = ImplicitGemmFusionKernel::kConvDim; + using WarpMmaOperator = typename ImplicitGemmFusionKernel::WarpMmaOperator; + using ArchMmaOperator = typename ImplicitGemmFusionKernel::ArchMmaOperator; + using MathOperator = typename ImplicitGemmFusionKernel::MathOperator; + + static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmFusionKernel::kConvolutionalOperator; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmFusionKernel::kIteratorAlgorithm; + + static int const kWarpCount = + (ThreadblockShape::kM / WarpShape::kM) * + (ThreadblockShape::kN / WarpShape::kN) * + (ThreadblockShape::kK / WarpShape::kK); + + /// Argument structure + using Arguments = typename ImplicitGemmFusionKernel::Arguments; + +private: + + /// Kernel parameters object + typename ImplicitGemmFusionKernel::Params params_; + +public: + + /// Constructs Implicit GEMM + ImplicitGemmConvolutionFusion() { } + + /// Determines whether the Implicit GEMM can execute the given problem. 
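Before the `can_implement` check below, a scalar reference of the element-wise preprocessing this fusion variant folds into the mainloop may help: per the file's description, a per-channel scale and bias followed by ReLU is applied to the activation operand as it is loaded rather than as a separate pass. The snippet is only an illustrative reference of the math, not the kernel's actual implementation:

```cpp
#include <algorithm>

// Illustrative scalar reference of the fused activation preprocessing
// (per-channel scale + bias, then ReLU) applied to each A-operand element.
inline float scale_bias_relu(float x, float channel_scale, float channel_bias) {
  return std::max(channel_scale * x + channel_bias, 0.0f);
}
```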
+ static Status can_implement(Arguments const &args) { + + // dispatch to iterators + Status status = ImplicitGemmFusionKernel::Mma::IteratorA::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + status = ImplicitGemmFusionKernel::Mma::IteratorB::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape( + threadblock_swizzle.get_tiled_shape( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size), + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices)); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) { + + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t workspace_bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size), + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + if(args.split_k_mode == SplitKMode::kParallel) { + + // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. + // The user needs to call a reduction operator to optain the final output tensor + workspace_bytes = + sizeof(ElementAccumulator) * + size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) * + size_t(grid_tiled_shape.k()); + } + + else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) { + + // Split-K serial: The user workspace is used to store semaphore and serialize writing the + // final reduced output to user's output tensor + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + return workspace_bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + if (args.problem_size.split_k_slices > 1) { + + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream); + + if (status != cudaSuccess) { + return Status::kErrorInternal; + } + } + + // initialize the params structure from the arguments + params_ = typename ImplicitGemmFusionKernel::Params( + args, + static_cast(workspace) + ); + + int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Initializes Impicit GEMM state from arguments. 
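The `update()` member follows next; before that, a worked example of the split-K workspace rule in `get_workspace_size()` above may be useful. Parallel split-K materializes one partial output tensor per k-slice, while serial split-K only needs one `int` semaphore per output tile. The concrete numbers below are illustrative, not taken from any particular configuration:

```cpp
#include <cstddef>

// Illustrative problem: fprop output of N*P*Q*K = 2*56*56*64 elements,
// float accumulators, split_k_slices = 4, and a 25x1 tiled output grid.
constexpr std::size_t kOutputElements = 2ull * 56 * 56 * 64;  // 401,408
constexpr int kSplitKSlices = 4;
constexpr int kTilesM = 25, kTilesN = 1;

// SplitKMode::kParallel: one partial accumulator tensor per slice.
constexpr std::size_t kParallelBytes =
    sizeof(float) * kOutputElements * kSplitKSlices;          // ~6.4 MB

// SplitKMode::kSerial: one semaphore per output tile.
constexpr std::size_t kSerialBytes = sizeof(int) * kTilesM * kTilesN;

static_assert(kParallelBytes == 6422528, "4 B x 401,408 x 4 slices");
static_assert(kSerialBytes == 100, "4 B x 25 x 1 tiles");
```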
+ Status update(Arguments const &args, void *workspace = nullptr) { + + // update the params structure from the arguments + params_.ptr_A = args.ref_A.data(); + params_.ptr_B = args.ref_B.data(); + params_.ptr_scale = args.ref_A_scale.data(); + params_.ptr_bias = args.ref_A_bias.data(); + params_.ptr_C = args.ref_C.data(); + params_.ptr_D = args.ref_D.data(); + params_.output_op = args.output_op; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(32 * kWarpCount, 1, 1); + + int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); + + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/dispatch_policy.hpp b/server/punica_kernels/include/cutlass/cutlass/conv/dispatch_policy.hpp new file mode 100644 index 00000000..32e3e417 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/dispatch_policy.hpp @@ -0,0 +1,83 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/convolution.h" +#include "cutlass/arch/arch.h" + +#include "cute/layout.hpp" +#include "cute/numeric/integral_constant.hpp" + +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv { + +////////////////////////////////////////////////////////////////////////////// + +// +// Policies for categorical dispatch of mainloop against kernel grid schedules +// +struct KernelImplicitTmaWarpSpecializedSm90 { }; +struct KernelImplicitTmaWarpSpecializedSm90Cooperative { }; +struct KernelImplicitTmaWarpSpecializedSm90Pingpong { }; + +// +// Collective Mainloop Policies +// + +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA +// for fprop +template< + conv::Operator ConvOp_, + int Stages_, + int NumSpatialDimensions_, + class ClusterShape_ = cute::Shape,cute::C<1>,cute::C<1>>, + class KernelSchedule = KernelImplicitTmaWarpSpecializedSm90, + int PipelineAsyncMmaStages_ = 1 +> +struct MainloopSm90TmaGmmaWarpSpecializedImplicitGemm { + static constexpr int Stages = Stages_; + static constexpr int NumSpatialDimensions = NumSpatialDimensions_; + static constexpr Operator ConvOp = ConvOp_; + static constexpr int PipelineAsyncMmaStages = PipelineAsyncMmaStages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; + + static_assert(NumSpatialDimensions >= 1); + static_assert(! (cute::is_same_v || + cute::is_same_v), + "Persistent schedules not support for conv yet."); +}; + +////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/conv_universal.hpp b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/conv_universal.hpp new file mode 100644 index 00000000..9d98dc9d --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/conv_universal.hpp @@ -0,0 +1,63 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +/* + * Stateless universal device CONV kernel type that treats CONV as + * a composition of a collective mainloop and a collective epilogue. +**/ +template < + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ = void, + class Enable = void +> +class ConvUniversal { + static_assert(cutlass::detail::dependent_false, + "Could not find a valid specialization at the kernel layer to dispatch against."); +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::kernel + +//////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp" +//////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d.h new file mode 100644 index 00000000..51310304 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d.h @@ -0,0 +1,293 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions for threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/conv/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" +#include "cutlass/conv/threadblock/implicit_gemm_pipelined.h" +#include "cutlass/conv/threadblock/implicit_gemm_multistage.h" +#include "cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h" +#include "cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_fusion.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename ArchTag, + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename OutputOp +> +struct DefaultConvEpilogue { + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp, + OutputOp::kCount + >::Epilogue; +}; + +template < + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename OutputOp +> +struct DefaultConvEpilogue< + arch::Sm70, + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp +> { + + using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOp< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp, + OutputOp::kCount + >::Epilogue; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ArchTag, + typename Shape, + typename WarpMmaSimt, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + int ElementsPerAccess +> +struct DefaultConvEpilogueWithBroadcastSimt { + using Epilogue = typename 
epilogue::threadblock::DefaultEpilogueWithBroadcastSimt< + Shape, + WarpMmaSimt, + ElementOutput, + ElementTensor, + ElementVector, + OutputOp, + ElementsPerAccess + >::Epilogue; +}; + +template < + typename ArchTag, + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + int ElementsPerAccess +> +struct DefaultConvEpilogueWithBroadcastTensorOp { + using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp< + Shape, + WarpMmaTensorOp, + PartitionsK, + ElementOutput, + ElementTensor, + ElementVector, + OutputOp, + ElementsPerAccess + >::Epilogue; +}; + +template < + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + int ElementsPerAccess +> +struct DefaultConvEpilogueWithBroadcastTensorOp< + arch::Sm70, + Shape, + WarpMmaTensorOp, + PartitionsK, + ElementOutput, + ElementTensor, + ElementVector, + OutputOp, + ElementsPerAccess + > { + using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOp< + Shape, + WarpMmaTensorOp, + PartitionsK, + ElementOutput, + ElementTensor, + ElementVector, + OutputOp, + ElementsPerAccess + >::Epilogue; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ArchTag, + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename ElementOutput, + typename OutputOp, + typename ReductionOp, + int ElementsPerAccess +> +struct DefaultConvEpilogueWithReductionTensorOp { + using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + PartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + ElementsPerAccess + >::Epilogue; +}; + +template < + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename ElementOutput, + typename OutputOp, + typename ReductionOp, + int ElementsPerAccess +> +struct DefaultConvEpilogueWithReductionTensorOp< + arch::Sm70, + Shape, + WarpMmaTensorOp, + PartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + ElementsPerAccess + > { + using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionVoltaTensorOp< + Shape, + WarpMmaTensorOp, + PartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + ElementsPerAccess + >::Epilogue; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Defaults for strided Dgrad +template < + typename ArchTag, + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename OutputOp +> +struct DefaultConvEpilogueStridedDgrad { + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp, + OutputOp::kCount + >::Epilogue; +}; + +template < + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename OutputOp +> +struct DefaultConvEpilogueStridedDgrad< + arch::Sm70, + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp +> { + + using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOpStridedDgrad< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp, + OutputOp::kCount + >::Epilogue; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + 
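All of the `detail::DefaultConv*Epilogue*` helpers above follow the same idiom: a primary template selects the generic tensor-op epilogue, and a partial specialization on `arch::Sm70` substitutes the Volta-specific one, so callers never branch on architecture themselves. A stripped-down sketch of that compile-time dispatch pattern, with purely illustrative names:

```cpp
// Illustrative tags standing in for cutlass::arch::Sm70 / later architectures.
struct VoltaTag {};
struct AmpereTag {};

template <typename ArchTag>
struct PickEpilogue {
  static constexpr bool kVoltaPath = false;   // generic: tensor-op epilogue
};

template <>
struct PickEpilogue<VoltaTag> {
  static constexpr bool kVoltaPath = true;    // Sm70: Volta tensor-op epilogue
};

static_assert(!PickEpilogue<AmpereTag>::kVoltaPath, "generic path chosen");
static_assert(PickEpilogue<VoltaTag>::kVoltaPath, "Volta path chosen at compile time");
```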
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_dgrad.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_dgrad.h new file mode 100644 index 00000000..8eb97951 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_dgrad.h @@ -0,0 +1,1927 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dDgrad +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultConv2dDgrad; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided and +// multistage pipeline. 
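The primary `DefaultConv2dDgrad` template above defaults `AlignmentA`/`AlignmentB` to `128 / sizeof_bits<Element>::value`, i.e. whatever element count fills a 128-bit global-memory access; the specializations that follow (starting with the Analytic/Strided multistage one right after this aside) simply inherit those values. A quick compile-time check of what that default works out to for common element types:

```cpp
#include <cstdint>
#include "cutlass/numeric_types.h"

static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8,
              "8 half_t elements per 128-bit access");
static_assert(128 / cutlass::sizeof_bits<float>::value == 4,
              "4 float elements per 128-bit access");
static_assert(128 / cutlass::sizeof_bits<int8_t>::value == 16,
              "16 int8_t elements per 128-bit access");
```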
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kStrided, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kStrided, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided +// and 2 stage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kStrided, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kStrided, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity Strided +// and multistage pipeline. 
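The specializations that follow switch from `StrideSupport::kStrided` to `StrideSupport::kUnity`: when the convolution stride is {1, 1}, dgrad can use the plain `ImplicitGemmConvolution` kernel, whereas general strides require the `ImplicitGemmConvolutionStridedDgrad` variant used above. A hedged sketch of how a dispatcher might choose between the two; CUTLASS's own operation tables make an equivalent decision, and this helper is illustrative only:

```cpp
#include "cutlass/conv/conv2d_problem_size.h"

// Illustrative predicate: unity-stride dgrad is valid only when both spatial
// strides are exactly one; anything else must take the strided-dgrad path.
inline bool can_use_unity_stride_dgrad(cutlass::conv::Conv2dProblemSize const &problem) {
  return problem.stride_h == 1 && problem.stride_w == 1;
}
```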
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kUnity, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity +// 2 stage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kUnity, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dDgrad specialization for optimized IteratorAlgorithm Dgrad Unity Strided +// and multistage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kUnity, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided and +// multistage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kStrided, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kStrided, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided +// and 2 stage pipeline. 
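Several of the multistage specializations above compute `CacheOpB` with the same test: only when a single B access spans a full 128 bits (`sizeof_bits<ElementB>::value * AlignmentB == 128`) do they request `CacheOperation::Global`, the `cp.async` variant that caches only at L2 and is available only for 16-byte copies; narrower accesses fall back to `CacheOperation::Always`. A compile-time restatement of that predicate with illustrative types, not the CUTLASS trait itself:

```cpp
#include <cstdint>

// Illustrative predicate mirroring the CacheOpB selection above.
template <typename Element, int Alignment>
constexpr bool uses_cache_global() {
  return sizeof(Element) * 8 * Alignment == 128;   // one access == 128 bits
}

static_assert(uses_cache_global<std::uint16_t, 8>(),   // e.g. half-precision, 8-wide access
              "128-bit accesses may use CacheOperation::Global");
static_assert(!uses_cache_global<std::uint16_t, 4>(),  // 64-bit accesses
              "narrower accesses use CacheOperation::Always");
```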
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kStrided, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kStrided, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Unity +// 2 stage pipeline +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 
2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kUnity, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + 
// Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kUnity + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kStrided + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = 
threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kUnity + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + 
typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kStrided + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using 
MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kUnity + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kUnity + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + 
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kUnity + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + 
>::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop.h new file mode 100644 index 00000000..7eb4e1bd --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop.h @@ -0,0 +1,1989 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/conv/kernel/default_conv2d.h"
+
+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h"
+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h"
+
+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h"
+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace conv {
+namespace kernel {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+/// Defines a kernel for Conv2dFprop
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ElementAccumulator,
+  typename OperatorClass,
+  typename ArchTag,
+  typename ThreadblockShape,
+  typename WarpShape,
+  typename InstructionShape,
+  typename EpilogueOutputOp,
+  typename ThreadblockSwizzle,
+  int Stages,
+  typename MathOperatorTag,
+  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
+  conv::StrideSupport StrideSupport = StrideSupport::kStrided,
+  /// Access granularity of A matrix in units of elements
+  int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
+  /// Access granularity of B matrix in units of elements
+  int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
+> struct DefaultConv2dFprop;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+// OpClassTensorOp convolutions
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
+/// pipeline.
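As a usage reference for the primary template above (not part of the header itself): a hedged sketch of a typical SM80 FP16 instantiation that relies on the kOptimized / kStrided defaults and the default 128-bit alignments, then wraps the kernel in the device-level ImplicitGemmConvolution runner. The concrete element types, tile shapes, and epilogue below are assumptions borrowed from common CUTLASS configurations, not values required by this file.

#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/numeric_types.h"

// Kernel-level type: every argument after MathOperatorTag is left at its default
// (IteratorAlgorithm::kOptimized, StrideSupport::kStrided, 128-bit alignments).
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
    cutlass::half_t, cutlass::layout::TensorNHWC,   // A: activations
    cutlass::half_t, cutlass::layout::TensorNHWC,   // B: filters
    cutlass::half_t, cutlass::layout::TensorNHWC,   // C: output
    float,                                          // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,         // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,           // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,            // tensor-core instruction shape
    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                              // pipeline stages (multistage mainloop)
    cutlass::arch::OpMultiplyAdd
>::Kernel;

// Device-level runner that launches the implicit-GEMM convolution kernel.
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;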
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kFixedChannels, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and two stage +/// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kFixedChannels, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kFewChannels, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kFewChannels, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline with interleaved layout. 
+template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB, + int InterleavedK +> +struct DefaultConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA = typename MmaCore::SmemThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapB = typename MmaCore::SmemThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm +/// and 2 stage pipeline. 
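The interleaved specialization above is selected when the activation and filter layouts are the interleaved tensor layouts parameterized by InterleavedK. A hedged sketch of the layout aliases for a 32-element interleaving, the usual int8 configuration, assumed here purely for illustration:

#include "cutlass/layout/tensor.h"

// Channels stored in groups of 32: these are the layouts the TensorNCxHWx / TensorCxRSKx
// partial specializations of DefaultConv2dFprop key on, with the interleaving factor
// supplying the InterleavedK template argument of the specialization.
using LayoutActivations = cutlass::layout::TensorNCxHWx<32>;  // interleaved activations
using LayoutFilters     = cutlass::layout::TensorCxRSKx<32>;  // interleaved filters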
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage +/// pipeline with interleaved layout. 
+template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB, + int InterleavedK +> +struct DefaultConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA = typename MmaCore::SmemThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. 
+ using ThreadMapB = typename MmaCore::SmemThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and +/// multistage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag + >; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and +// multistage pipeline with interleaved layout. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB, + int InterleavedK +> +struct DefaultConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true + >; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::SmemThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + layout::TensorNCxHWx, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::SmemThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + layout::TensorCxRSKx, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + InterleavedK + >::Epilogue; + + // Define 
the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm +/// and 2 stage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage +/// pipeline with interleaved layout. 
+template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB, + int InterleavedK +> +struct DefaultConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::SmemThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::SmemThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct 
DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using 
IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using 
Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + 
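The DefaultConv2dFprop specializations above are meant to be consumed by naming one of them in a kernel typedef and handing that kernel to the device-level implicit GEMM driver. The following sketch shows one plausible instantiation; the half-precision element types, NHWC layouts, SM80 architecture tag, 128x128x32 threadblock tile, and stage count are illustrative assumptions rather than values fixed by this header.

// Sketch only: an illustrative instantiation of a DefaultConv2dFprop configuration
// (types, layouts, tile shapes, and stage count are assumptions, not part of this diff).
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
    cutlass::half_t, cutlass::layout::TensorNHWC,        // A: activations
    cutlass::half_t, cutlass::layout::TensorNHWC,        // B: filters
    cutlass::half_t, cutlass::layout::TensorNHWC,        // C: output
    float,                                                // accumulator
    cutlass::arch::OpClassTensorOp,                       // selects the tensor-op specializations
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,               // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,                  // mma instruction shape
    cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, 8, float, float>,                // epilogue: alpha * acc + beta * C
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                                    // >2 stages -> multistage mainloop path
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized          // picks a kOptimized specialization
>::Kernel;

// Device-level wrapper that launches the kernel produced by the configuration above.
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;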
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop_fusion.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop_fusion.h new file mode 100644 index 00000000..97bda7b9 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop_fusion.h @@ -0,0 +1,357 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief + Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution + definitions that combine threadblock-scoped matrix multiply-add with the + appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" +#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for fused batch norm and Conv2dFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultConv2dFpropFusion; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv2dFpropFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + using SmemIteratorScaleBias = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the Mma + using Mma = threadblock::ImplicitGemmFpropFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + IteratorScaleBias, + SmemIteratorScaleBias, + arch::CacheOperation::Always, + MmaPolicy, + WarpIteratorScaleBias, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for 
Conv2dFprop specialization for Optimzed IteratorAlgorithm and +/// multistage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv2dFpropFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag + >; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + using SmemIteratorScaleBias = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the Mma + using Mma = threadblock::ImplicitGemmFpropFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + IteratorScaleBias, + SmemIteratorScaleBias, + arch::CacheOperation::Always, + MmaPolicy, + WarpIteratorScaleBias, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + 
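DefaultConv2dFpropFusion follows the same pattern as the plain fprop defaults but additionally threads a per-channel scale/bias operand (ElementScaleBias / LayoutScaleBias) through the mainloop. Below is a sketch of how such a fused kernel might be instantiated; all concrete element types, tile shapes, and the stage count are assumptions chosen for illustration, and the ReLU epilogue is just one possible choice of EpilogueOutputOp.

// Sketch only: one plausible instantiation of DefaultConv2dFpropFusion
// (element types, layouts, tile shapes, and stage count are assumptions).
#include "cutlass/conv/kernel/default_conv2d_fprop_fusion.h"
#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

using Conv2dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv2dFpropFusion<
    cutlass::half_t, cutlass::layout::TensorNHWC,        // A: activations
    cutlass::half_t, cutlass::layout::TensorNHWC,        // B: filters
    cutlass::half_t, cutlass::layout::RowMajor,          // per-channel scale/bias vectors
    cutlass::half_t, cutlass::layout::TensorNHWC,        // C: output
    float,                                                // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombinationRelu<
        cutlass::half_t, 8, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,                                                    // stages
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized
>::Kernel;

// Fusion-aware device wrapper; its arguments carry the extra scale/bias tensor refs.
using Conv2dFpropFusion =
    cutlass::conv::device::ImplicitGemmConvolutionFusion<Conv2dFpropFusionKernel>;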
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h new file mode 100644 index 00000000..b0e0ae65 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Defines a default configuration for convolution with absolute maximum calculation. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_absmax.h" +#include "cutlass/epilogue/threadblock/epilogue_with_absmax.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultConv2dFpropWithAbsMax { + + using ImplicitGemmBase = typename DefaultConv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithAbsMax< + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementAuxOutput, + ElementC, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithAbsMax< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h new file mode 100644 index 00000000..afc66986 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h @@ -0,0 +1,220 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. 
+ * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Reduction based on an existing UniversalGemm kernel. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultConv2dFpropWithBroadcast { + + using ImplicitGemmBase = typename DefaultConv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename 
EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFpropWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultConv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h new file mode 100644 index 00000000..8d0c1d47 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h @@ -0,0 +1,130 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Reduction based on an existing UniversalGemm kernel. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" +#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename EpilogueReductionOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultConv2dFpropWithReduction { + + using ImplicitGemmBase = typename DefaultConv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + 
AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithReductionTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + EpilogueOutputOp, + EpilogueReductionOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_group_fprop.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_group_fprop.h new file mode 100644 index 00000000..07bff918 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_group_fprop.h @@ -0,0 +1,622 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dGroupFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::GroupMode GroupMode, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultConv2dGroupFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline that supports all GroupMode. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::GroupMode GroupMode, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + GroupMode, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA, + GroupMode + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB, + GroupMode + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and +/// 2 stage pipeline that supports all GroupMode. 
+ +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::GroupMode GroupMode, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + GroupMode, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA, + GroupMode + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB, + GroupMode + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and multistage +/// pipeline that supports GroupMode::kSingleGroup. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + GroupMode::kSingleGroup, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode::kSingleGroup + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and +/// 2 stage pipeline that supports GroupMode::kSingleGroup. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + GroupMode::kSingleGroup, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode::kSingleGroup + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_wgrad.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_wgrad.h new file 
mode 100644 index 00000000..e0f09fd2 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_wgrad.h @@ -0,0 +1,1011 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/conv/kernel/default_conv2d.h"
+
+#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h"
+#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h"
+#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h"
+#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h"
+#include "cutlass/conv/threadblock/conv2d_tile_iterator.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace conv {
+namespace kernel {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Defines a kernel for Conv2dWgrad
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ElementAccumulator,
+  typename OperatorClass,
+  typename ArchTag,
+  typename ThreadblockShape,
+  typename WarpShape,
+  typename InstructionShape,
+  typename EpilogueOutputOp,
+  typename ThreadblockSwizzle,
+  int Stages,
+  typename MathOperatorTag,
+  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
+  conv::StrideSupport StrideSupport = StrideSupport::kStrided,
+  /// Access granularity of A matrix in units of elements
+  int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
+  /// Access granularity of B matrix in units of elements
+  int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
+> struct DefaultConv2dWgrad;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+// OpClassTensorOp convolutions
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and multistage
+// pipeline.
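
Editor's note (illustrative sketch, not part of the vendored header or of this patch): for weight-gradient kernels, operand A is the output gradient and operand B the activations, as the iterator headers included above indicate. In the builder, Stages == 2 selects the pipelined specializations further down, while Stages > 2 selects the multistage ones. Every concrete type and shape below is an assumption chosen for illustration.

// Hypothetical instantiation sketch; parameter choices are assumptions.
using WgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad<
    cutlass::half_t, cutlass::layout::TensorNHWC,   // ElementA / LayoutA: output gradient
    cutlass::half_t, cutlass::layout::TensorNHWC,   // ElementB / LayoutB: activations
    float, cutlass::layout::TensorNHWC,             // ElementC / LayoutC: filter gradient
    float,                                          // ElementAccumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                              // Stages > 2: multistage mainloop
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized
>::Kernel;
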
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and two +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and multistage +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and two +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AccessTypeA, + int AccessTypeB +> 
+struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AccessTypeA, + AccessTypeB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AccessTypeA, + int AccessTypeB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AccessTypeA, + AccessTypeB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + 
using IteratorA = + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AccessTypeA, + int AccessTypeB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AccessTypeA, + AccessTypeB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = 
threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AccessTypeA, + int AccessTypeB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AccessTypeA, + AccessTypeB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + 
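
Editor's note (illustrative sketch, not part of the vendored header or of this patch): the ::Kernel types produced by the specializations above are not launched directly; in CUTLASS they are normally wrapped by the device-level adapter from cutlass/conv/device/implicit_gemm_convolution.h. A minimal sketch, assuming WgradKernel is the hypothetical alias from the earlier note:

// Hypothetical device-level wrapping sketch.
#include "cutlass/conv/device/implicit_gemm_convolution.h"

using WgradOperation = cutlass::conv::device::ImplicitGemmConvolution<WgradKernel>;

// WgradOperation::Arguments is built from a cutlass::conv::Conv2dProblemSize plus tensor
// references, and the operation object is then invoked as a callable on a CUDA stream.
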
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h new file mode 100644 index 00000000..8fe8713d --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h @@ -0,0 +1,325 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" +#include "cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultConv2dWgradFusion; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and multistage +// pipeline. 
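
Editor's note (illustrative sketch, not part of the vendored header or of this patch): compared with DefaultConv2dWgrad, this builder takes two extra parameters describing the per-channel scale/bias vectors consumed by the fused mainloop, and its ::Kernel is an ImplicitGemmConvolutionFusion. All concrete types below are assumptions chosen for illustration.

// Hypothetical instantiation sketch; parameter choices are assumptions.
using WgradFusionKernel = typename cutlass::conv::kernel::DefaultConv2dWgradFusion<
    cutlass::half_t, cutlass::layout::TensorNHWC,   // output gradient
    cutlass::half_t, cutlass::layout::TensorNHWC,   // activations
    cutlass::half_t, cutlass::layout::RowMajor,     // ElementScaleBias / LayoutScaleBias
    float, cutlass::layout::TensorNHWC,             // filter gradient
    float,                                          // ElementAccumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,                                              // Stages; only multistage specializations appear in this header
    cutlass::arch::OpMultiplyAdd
>::Kernel;
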
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv2dWgradFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorIterator< + cutlass::MatrixShape<1, WarpShape::kN>, + ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmWgradFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + IteratorScaleBias, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and multistage +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv2dWgradFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorIterator< + cutlass::MatrixShape<1, WarpShape::kN>, + ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmWgradFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + IteratorScaleBias, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_dgrad.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_dgrad.h new file mode 100644 index 00000000..a52f24cb --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_dgrad.h @@ -0,0 +1,735 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h" + +#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dDgrad +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultConv3dDgrad; + +/// Defines a kernel for Conv3dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided +// and multistage pipeline. 
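
Editor's note (illustrative sketch, not part of the vendored header or of this patch): the 3-D data-gradient builder declared above follows the same pattern with NDHWC tensors and Conv3dProblemSize; note that the optimized tensor-op specialization further down is only provided for unit stride (StrideSupport::kUnity). All concrete parameters in this hypothetical sketch are assumptions.

// Hypothetical instantiation sketch; parameter choices are assumptions.
using Dgrad3dKernel = typename cutlass::conv::kernel::DefaultConv3dDgrad<
    cutlass::half_t, cutlass::layout::TensorNDHWC,  // ElementA / LayoutA: output gradient
    cutlass::half_t, cutlass::layout::TensorNDHWC,  // ElementB / LayoutB: filters
    cutlass::half_t, cutlass::layout::TensorNDHWC,  // ElementC / LayoutC: input gradient
    float,                                          // ElementAccumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                              // Stages > 2: multistage mainloop
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kUnity
>::Kernel;
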
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided +// and multistage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + + using IteratorB = + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< 
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + // ThreadMapB, + // 
StrideSupport::kUnity + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; +}; + 
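+
+// [Editor's sketch, not part of the upstream CUTLASS header] The DefaultConv3dDgrad
+// specializations in this file only assemble a kernel type; a caller typically names one
+// specialization and wraps its ::Kernel in the device-level ImplicitGemmConvolution
+// template. Every concrete choice below (element types, NDHWC layouts, tile shapes,
+// stage count) is an illustrative assumption, not something this header mandates:
+//
+//   using Conv3dDgradKernel = typename cutlass::conv::kernel::DefaultConv3dDgrad<
+//     cutlass::half_t, cutlass::layout::TensorNDHWC,   // A: output gradient
+//     cutlass::half_t, cutlass::layout::TensorNDHWC,   // B: filter
+//     cutlass::half_t, cutlass::layout::TensorNDHWC,   // C: input gradient
+//     float,                                           // accumulator
+//     cutlass::arch::OpClassTensorOp,
+//     cutlass::arch::Sm80,
+//     cutlass::gemm::GemmShape<128, 128, 32>,          // threadblock tile
+//     cutlass::gemm::GemmShape<64, 64, 32>,            // warp tile
+//     cutlass::gemm::GemmShape<16, 8, 16>,             // instruction shape
+//     cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
+//     cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+//     3,                                               // stages
+//     cutlass::arch::OpMultiplyAdd,
+//     cutlass::conv::IteratorAlgorithm::kOptimized,
+//     cutlass::conv::StrideSupport::kUnity             // unit-stride dgrad specialization
+//   >::Kernel;
+//
+//   using Conv3dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv3dDgradKernel>;
+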
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + // ThreadMapB, + // StrideSupport::kUnity + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_fprop.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_fprop.h new file mode 100644 index 00000000..0e5b09db --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_fprop.h @@ -0,0 +1,944 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" + + +#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultConv3dFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Analytic Iterator Algorithm +/// and 2 stage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm and multistage +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Optimized Iterator Algorithm +/// and 2 stage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm and multistage +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM 
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using 
SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; + +}; + 
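+
+// [Editor's sketch, not part of the upstream CUTLASS header] Once a DefaultConv3dFprop
+// specialization has composed Mma, Epilogue, and Kernel, the usual entry point is the
+// device-level wrapper. The problem-size fields, tensor refs, and epilogue scalars below
+// are placeholders chosen for illustration only:
+//
+//   using Conv3dFprop =
+//       cutlass::conv::device::ImplicitGemmConvolution<typename DefaultConv3dFprop<
+//           /* ...element types, layouts, tile shapes, etc. as above... */>::Kernel>;
+//
+//   cutlass::conv::Conv3dProblemSize problem_size(/* N, D, H, W, C, K, T, R, S, ... */);
+//   typename Conv3dFprop::Arguments args{
+//       problem_size,
+//       tensor_a.device_ref(),   // activations (NDHWC)
+//       tensor_b.device_ref(),   // filters (KTRSC)
+//       tensor_c.device_ref(),   // epilogue source tensor (scaled by beta)
+//       tensor_d.device_ref(),   // destination
+//       {alpha, beta}};          // linear-combination epilogue parameters
+//
+//   Conv3dFprop conv_op;
+//   conv_op.initialize(args);
+//   conv_op();                   // launches the implicit-GEMM kernel
+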
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_fprop_fusion.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_fprop_fusion.h new file mode 100644 index 00000000..98c930a2 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_fprop_fusion.h @@ -0,0 +1,360 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution + definitions that combine threadblock-scoped matrix multiply-add with the + appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" +#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for fused batch norm and Conv3dFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultConv3dFpropFusion; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialzation for Analytic IteratorAlgorithm and multistage +/// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dFpropFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + using SmemIteratorScaleBias = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the Mma + using Mma = threadblock::ImplicitGemmFpropFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + IteratorScaleBias, + SmemIteratorScaleBias, + arch::CacheOperation::Always, + MmaPolicy, + WarpIteratorScaleBias, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for 
Conv3dFprop specialzation for Optimzed IteratorAlgorithm and +/// multistage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dFpropFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag + >; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + using SmemIteratorScaleBias = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the Mma + using Mma = threadblock::ImplicitGemmFpropFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + IteratorScaleBias, + SmemIteratorScaleBias, + arch::CacheOperation::Always, + MmaPolicy, + WarpIteratorScaleBias, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + 
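+
+// [Editor's sketch, not part of the upstream CUTLASS header] Unlike the plain fprop
+// defaults, these fused specializations thread a per-channel scale/bias vector through the
+// mainloop (IteratorScaleBias / WarpIteratorScaleBias), so the composed ::Kernel is meant
+// to be wrapped by the fusion-aware device-level template rather than the plain one. The
+// alias names below are illustrative:
+//
+//   using Conv3dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv3dFpropFusion<
+//       /* element/layout/scale-bias/accumulator types, tile shapes, stages, ... */>::Kernel;
+//
+//   using Conv3dFpropFusion =
+//       cutlass::conv::device::ImplicitGemmConvolutionFusion<Conv3dFpropFusionKernel>;
+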
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h new file mode 100644 index 00000000..e079bf3d --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h @@ -0,0 +1,218 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Reduction based on an existing UniversalGemm kernel. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv3d_fprop.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultConv3dFpropWithBroadcast { + + using ImplicitGemmBase = typename DefaultConv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv3dFpropWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + 
IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultConv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_wgrad.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_wgrad.h new file mode 100644 index 00000000..69b444f3 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_conv3d_wgrad.h @@ -0,0 +1,936 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultConv3dWgrad; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm and multistage +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm and two +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm and multistage +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm and two +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = 
typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename 
MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; +}; + 
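// A minimal usage sketch for the DefaultConv3dWgrad helper defined in this
// header, shown in comment form. The element types, layouts, tile shapes,
// swizzle, stage count and epilogue below are illustrative SM80 tensor-op
// choices, not values mandated by this file.
//
//   #include "cutlass/conv/kernel/default_conv3d_wgrad.h"
//   #include "cutlass/conv/device/implicit_gemm_convolution.h"
//   #include "cutlass/epilogue/thread/linear_combination.h"
//   #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
//
//   using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad<
//     cutlass::half_t, cutlass::layout::TensorNDHWC,   // ElementA / LayoutA (output gradient)
//     cutlass::half_t, cutlass::layout::TensorNDHWC,   // ElementB / LayoutB (activation)
//     float,           cutlass::layout::TensorNDHWC,   // ElementC / LayoutC (weight gradient)
//     float,                                           // ElementAccumulator
//     cutlass::arch::OpClassTensorOp,
//     cutlass::arch::Sm80,
//     cutlass::gemm::GemmShape<128, 128, 32>,          // ThreadblockShape
//     cutlass::gemm::GemmShape<64, 64, 32>,            // WarpShape
//     cutlass::gemm::GemmShape<16, 8, 16>,             // InstructionShape
//     cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
//     cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
//     3,                                               // Stages
//     cutlass::arch::OpMultiplyAdd,
//     cutlass::conv::IteratorAlgorithm::kOptimized
//   >::Kernel;
//
//   // Device-level wrapper that configures and launches the kernel above.
//   using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv3dWgradKernel>;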
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_depthwise_fprop.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_depthwise_fprop.h new file mode 100644 index 00000000..9ba28362 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/default_depthwise_fprop.h @@ -0,0 +1,588 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level Depthwise implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" +#include "cutlass/conv/kernel/direct_convolution.h" + +#include "cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/depthwise_fprop_pipelined.h" + +// Direct Conv Related Header files +#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h" +#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h" + +#include "cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h" +#include "cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for DepthwiseFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value +> struct DefaultDepthwiseFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for DepthwiseFprop with direct convolution algorithm +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename ThreadBlockOutputShape, + typename FilterShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + // MatrixShape + typename StrideShape = cutlass::MatrixShape<-1, -1>, + // MatrixShape< Height, Width> + typename DilationShape = cutlass::MatrixShape<-1, -1>, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultDepthwiseDirect2dConvFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Depthwise specialization for Analytic 
IteratorAlgorithm +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultDepthwiseFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, // cutlass::arch::OpMultiplyAdd + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::conv::threadblock::DepthwiseMmaCoreWithLaneAccessSize< + ThreadblockShape, + WarpShape, + InstructionShape, + ElementA, + layout::RowMajor, + ElementB, + layout::ColumnMajor, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + 128, + sizeof_bits::value, + 2, + MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB, + cutlass::conv::GroupMode::kDepthwise + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::DepthwiseFpropPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + cutlass::conv::GroupMode::kDepthwise + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, +/// multiple stage pipeline, and SIMT-based mainloop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename ThreadBlockOutputShape, + typename FilterShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + typename StrideShape, + 
typename DilationShape, + int AlignmentA, + int AlignmentB +> +struct DefaultDepthwiseDirect2dConvFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + StrideShape, + DilationShape, + AlignmentA, + AlignmentB +> { + // One warp handles the entrie groups per cta. + static_assert(ThreadblockShape::kN == WarpShape::kN, + "ThreadblockShape::kN should be same as WarpShape::kN "); + static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, + "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); + static_assert(ThreadblockShape::kM % WarpShape::kM == 0, + "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); + static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + ElementA, + layout::RowMajor, + ElementB, + layout::ColumnMajor, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + 128, + 128, + Stages, + MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized< + cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> + ThreadBlockOutputShape, + ElementA, LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + using ThreadOutputShape = typename MmaCore::ThreadOutputShape; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * AlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< + ThreadblockShape, // < outputShape:KMNK, groups per cta> + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + ThreadOutputShape, + ThreadBlockOutputShape + >::Epilogue; + + // Define the Mma + using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages, + Epilogue + >; + + // Define the kernel + using Kernel = cutlass::conv::kernel::DirectConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + cutlass::conv::GroupMode::kDepthwise, + ThreadBlockOutputShape + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, +/// multiple stage pipeline, and SIMT-based mainloop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename ThreadBlockOutputShape, + typename FilterShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + typename StrideShape, + typename DilationShape, + int AlignmentA, + int AlignmentB +> +struct DefaultDepthwiseDirect2dConvFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kFixedStrideDilation, + StrideSupport, + StrideShape, + DilationShape, + AlignmentA, + AlignmentB +> { + + + + // One warp handles the entrie groups per cta. 
+ static_assert(ThreadblockShape::kN == WarpShape::kN, + "ThreadblockShape::kN should be same as WarpShape::kN "); + static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, + "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); + static_assert(ThreadblockShape::kM % WarpShape::kM == 0, + "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); + static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); + + static_assert(StrideShape::kRow >= 0 && StrideShape::kColumn >= 0, "Stride should be fixed"); + static_assert(DilationShape::kRow >= 0 && DilationShape::kColumn >= 0, "Stride should be fixed"); + + // Activations loaded by threadblock + static int const ActivationShapeH = (ThreadBlockOutputShape::kH - 1) * StrideShape::kRow + + (FilterShape::kRow - 1) * DilationShape::kRow + 1; + + static int const ActivationShapeW = (ThreadBlockOutputShape::kW - 1) * StrideShape::kColumn + + (FilterShape::kColumn - 1) * DilationShape::kColumn + 1; + + using ActivationShape = + cutlass::conv::TensorNHWCShape<1, ActivationShapeH, ActivationShapeW, ThreadblockShape::kN >; + + // Define the core components from GEMM + using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + ElementA, + layout::RowMajor, + ElementB, + layout::ColumnMajor, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + 128, + 128, + Stages, + MathOperatorTag, + IteratorAlgorithm::kFixedStrideDilation, + StrideShape, + DilationShape, + ActivationShape>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation< + cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> + ThreadBlockOutputShape, + StrideShape, + DilationShape, + ActivationShape, + ElementA, LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + using ThreadOutputShape = typename MmaCore::ThreadOutputShape; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * AlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< + ThreadblockShape, // < outputShape:KMNK, groups per cta> + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + ThreadOutputShape, + ThreadBlockOutputShape + >::Epilogue; + + // Define the Mma + using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages, + Epilogue, + IteratorAlgorithm::kFixedStrideDilation + >; + + // Define the kernel + using Kernel = cutlass::conv::kernel::DirectConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + cutlass::conv::GroupMode::kDepthwise, + ThreadBlockOutputShape + >; +}; + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/direct_convolution.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/direct_convolution.h new file mode 100644 index 00000000..74c79c41 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/direct_convolution.h @@ -0,0 +1,505 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multi-staged Depthwise Convolution kernel. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure +template > ///! OutputShape per ThreadBlock +struct DirectConvolutionParams { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + static Operator const kConvolutionalOperator = ConvOperator; + using ConvProblemSize = ConvProblemSize_; + using Arguments = Arguments_; + using ConvOutputIteratorParameter = ConvOutputIteratorParameter_; + + using ThreadblockShape = typename Mma::Shape; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static conv::GroupMode const kGroupMode = GroupMode_; + static int const kStages = Mma::kStages; + + ConvProblemSize problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size; + int swizzle_log_tile; + int smem_size_; + + int gemm_k_iterations; + int gemm_k_iterations_per_channel; + typename Mma::IteratorA::Params iterator_A; + typename Mma::IteratorA::Element const *ptr_A; + typename Mma::IteratorB::Params iterator_B; + typename Mma::IteratorB::Element const *ptr_B; + typename Mma::IteratorB::Element *ptr_reordered_B; + typename Epilogue::OutputTileIterator::Params iterator_C; + typename Epilogue::OutputTileIterator::Element *ptr_C; + typename Epilogue::OutputTileIterator::Params iterator_D; + typename Epilogue::OutputTileIterator::Element *ptr_D; + typename EpilogueOutputOp::Params output_op; + int *semaphore; + SplitKMode split_k_mode; + int split_k_slices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DirectConvolutionParams() : swizzle_log_tile(0), gemm_k_iterations(0) {} + + /// + CUTLASS_HOST_DEVICE + DirectConvolutionParams(Arguments const &args, int *semaphore = nullptr) + : problem_size(args.problem_size), + implicit_gemm_problem_size( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(Mma::IteratorB::getParams(args.problem_size, args.ref_B.layout())), + ptr_B(args.ref_B.data()), + ptr_reordered_B(args.ref_reordered_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode), + split_k_slices(args.problem_size.split_k_slices) { + gemm_k_iterations = + depthwise_gemm_k_iterations(kConvolutionalOperator, + ThreadblockShape::kK, + args.problem_size, + kIteratorAlgorithm, + kGroupMode, + 
ThreadblockShape::kN); + + gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( + kConvolutionalOperator, args.problem_size, kIteratorAlgorithm); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + + // Dynamic SMEM usage because stride and dilation are runtime params. + smem_size_ = (iterator_A.activation_size * kStages + iterator_B.filter_size); + } + + CUTLASS_HOST_DEVICE + int get_smem_size() { + // Dynamic Smem Size + return smem_size_; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ReorderKernel { + using Params = Params_; + using ElementB = ElementB_; + + union SharedStorage {}; + + static unsigned int const kReorderKernelThreadPerCTA = 128; + + CUTLASS_HOST_DEVICE + ReorderKernel() {} + + CUTLASS_HOST_DEVICE + static dim3 get_grid_shape(Params const ¶ms) { + return dim3{static_cast( + (params.problem_size.filter_size() + kReorderKernelThreadPerCTA - 1) / + kReorderKernelThreadPerCTA), + 1, + 1}; + } + + CUTLASS_HOST_DEVICE + static dim3 get_block_shape() { return dim3{kReorderKernelThreadPerCTA, 1, 1}; } + + CUTLASS_HOST_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + int64_t m = static_cast(params.problem_size.groups); + int64_t n = static_cast(params.problem_size.filter_size() / params.problem_size.K); + const ElementB *src_with_type = static_cast(params.ptr_B); + ElementB *dst_with_type = static_cast(params.ptr_reordered_B); + + int64_t linear_index = blockIdx.x * kReorderKernelThreadPerCTA + threadIdx.x; + int64_t index_m = linear_index / n; + int64_t index_n = linear_index % n; + int64_t new_linear_index = index_m + index_n * m; + + if (linear_index < m * n) { + dst_with_type[new_linear_index] = src_with_type[linear_index]; + } + return; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem + conv::GroupMode GroupMode_ = conv::GroupMode::kNone, ///! 
Group mode + typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> +> +struct DirectConvolution { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueOutputOp::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename cutlass::gemm::GemmShape<1, 1, 1>; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = GroupMode_; + + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefB ref_reordered_B; + TensorRefC ref_C; + TensorRefC ref_D; + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + TensorRefB const & ref_reordered_B = nullptr, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), 
+ ref_reordered_B(ref_reordered_B), + split_k_mode(split_k_mode) + { + + } + + }; + + using Params = + typename cutlass::conv::kernel::DirectConvolutionParams; + + using ReorderKernel = typename cutlass::conv::kernel::ReorderKernel; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DirectConvolution() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if threadblock is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + int iterator_column_offset = 0; + int filter_row_offset = 0; + if (kGroupMode != GroupMode::kNone) { + if (kGroupMode == GroupMode::kDepthwise) { + iterator_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; + } + } + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() + threadblock_tile_idx.k(), + iterator_column_offset + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_reordered_B, + thread_idx, + MatrixCoord( + filter_row_offset, + iterator_column_offset + ) + ); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() + threadblock_tile_idx.k(), + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + + // Compute threadblock-scoped matrix multiply-add + // Epilogue is fused in the mainloop + mma(params.gemm_k_iterations, + accumulators, + iterator_A, + params.iterator_A, + iterator_B, + params.iterator_B, + accumulators, + epilogue, + output_op, + iterator_D, + iterator_C, + params.split_k_slices); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution.h new file mode 100644 index 00000000..f65bf259 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution.h @@ -0,0 +1,456 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined Implicit GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem + conv::GroupMode GroupMode_ = conv::GroupMode::kNone ///! 
Group mode +> +struct ImplicitGemmConvolution { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueOutputOp::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = GroupMode_; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C (KxTRSC) + static int const kWgradCStrideIdx = + platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? 
kWgradCStrideIdx : 0); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefC ref_C; + TensorRefC ref_D; + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), + split_k_mode(split_k_mode) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size; + int swizzle_log_tile; + + int gemm_k_iterations; + int gemm_k_iterations_per_channel; + typename Mma::IteratorA::Params iterator_A; + typename Mma::IteratorA::Element const *ptr_A; + typename Mma::IteratorB::Params iterator_B; + typename Mma::IteratorB::Element const *ptr_B; + typename Epilogue::OutputTileIterator::Params iterator_C; + typename Epilogue::OutputTileIterator::Element *ptr_C; + typename Epilogue::OutputTileIterator::Params iterator_D; + typename Epilogue::OutputTileIterator::Element *ptr_D; + typename EpilogueOutputOp::Params output_op; + int *semaphore; + SplitKMode split_k_mode; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), gemm_k_iterations(0) { } + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size(args.problem_size), + implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(args.problem_size, args.ref_B.layout()), + ptr_B(args.ref_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode) + { + gemm_k_iterations = implicit_gemm_k_iterations( + kConvolutionalOperator, + ThreadblockShape::kK, + args.problem_size, + kIteratorAlgorithm, + kGroupMode, + ThreadblockShape::kN); + + gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( + kConvolutionalOperator, args.problem_size, kIteratorAlgorithm); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + implicit_gemm_problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename 
Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ImplicitGemmConvolution() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + int iterator_A_column_offset = threadblock_tile_idx.k() * Mma::Shape::kK; + if (kGroupMode != GroupMode::kNone) { + if (kGroupMode != GroupMode::kDepthwise) { + int k_per_group = params.problem_size.K / params.problem_size.groups; + int group_idx = threadblock_tile_idx.n() * Mma::Shape::kN / k_per_group; + int channels_per_group = params.problem_size.C / params.problem_size.groups; + iterator_A_column_offset += group_idx * channels_per_group; + } else { + iterator_A_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; + } + } + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() * Mma::Shape::kM, + iterator_A_column_offset + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_B, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * Mma::Shape::kK, + threadblock_tile_idx.n() * Mma::Shape::kN + ) + ); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, params.gemm_k_iterations_per_channel); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. 
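+      // Serial split-K overview: the reduction (gemm_k) dimension is sliced across
+      // grid_tiled_shape.k() threadblocks that all target the same output tile, and this
+      // per-tile semaphore serializes their epilogues. fetch() issues a non-blocking read
+      // of the lock now so its latency can overlap with the iterator construction below;
+      // the blocking wait() happens just before the epilogue runs. set_k_partition() tells
+      // the output operator which slice this block owns: with the stock LinearCombination-
+      // style epilogues, partitions after the first typically accumulate onto the partial
+      // result already held in D rather than re-applying beta to the original source.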
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + + } + // Each split-k-slice writes to a unique tensor location + else if (params.split_k_mode == SplitKMode::kParallel) { + iterator_D.add_pointer_offset(threadblock_tile_idx.k() * + cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); + } + + // Run efficient epilogue + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h new file mode 100644 index 00000000..32821f9c --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h @@ -0,0 +1,461 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined fused activation's scale+bias+relu and Implicit GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize ///! 
Convolutional operator on 2D or 3D problem +> +struct ImplicitGemmConvolutionFusion { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + + using ElementScaleBias = typename Mma::IteratorScaleBias::Element; + using LayoutScaleBias = typename Mma::IteratorScaleBias::Layout; + + using ElementC = typename EpilogueOutputOp::ElementOutput; + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefScaleBias = typename Mma::IteratorScaleBias::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C (KxTRSC) + static int const kWgradCStrideIdx = + platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? 
kWgradCStrideIdx : 0); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefScaleBias ref_scale; + TensorRefScaleBias ref_bias; + TensorRefC ref_C; + TensorRefC ref_D; + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefScaleBias const & ref_scale, + TensorRefScaleBias const & ref_bias, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_scale(ref_scale), + ref_bias(ref_bias), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), + split_k_mode(split_k_mode) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size{}; + cutlass::gemm::GemmCoord grid_tiled_shape{}; + gemm::GemmCoord implicit_gemm_problem_size{}; + int swizzle_log_tile{0}; + int gemm_k_iterations{0}; + typename Mma::IteratorA::Params iterator_A{}; + typename Mma::IteratorA::Element const *ptr_A = nullptr; + typename Mma::IteratorB::Params iterator_B{}; + typename Mma::IteratorB::Element const *ptr_B = nullptr; + typename Mma::IteratorScaleBias::Params iterator_scale_bias{}; + typename Mma::IteratorScaleBias::Element const *ptr_scale = nullptr; + typename Mma::IteratorScaleBias::Element const *ptr_bias = nullptr; + typename Epilogue::OutputTileIterator::Params iterator_C {}; + typename Epilogue::OutputTileIterator::Element *ptr_C = nullptr; + typename Epilogue::OutputTileIterator::Params iterator_D {}; + typename Epilogue::OutputTileIterator::Element *ptr_D = nullptr; + typename EpilogueOutputOp::Params output_op {}; + int *semaphore = nullptr; + SplitKMode split_k_mode {}; + + // + // Methods + // + Params() = default; + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size(args.problem_size), + implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(args.problem_size, args.ref_B.layout()), + ptr_B(args.ref_B.data()), + iterator_scale_bias(args.problem_size, args.ref_scale.layout()), + ptr_scale(args.ref_scale.data()), + ptr_bias(args.ref_bias.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode) + { + gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + implicit_gemm_problem_size, + {ThreadblockShape::kM, 
ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ImplicitGemmConvolutionFusion() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A operand + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.k() * Mma::Shape::kK + ) + ); + + // Construct iterators to B operand + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_B, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * Mma::Shape::kK, + threadblock_tile_idx.n() * Mma::Shape::kN + ) + ); + + // Construct iterators to A scale/bias vector + typename Mma::IteratorScaleBias iterator_scale_bias( + params.iterator_scale_bias, + params.problem_size, + params.ptr_scale, + params.ptr_bias, + thread_idx, + MatrixCoord( + 0, (kConvolutionalOperator == conv::Operator::kFprop) ? + (threadblock_tile_idx.k() * Mma::Shape::kK) : + // Wgrad + (threadblock_tile_idx.n() * Mma::Shape::kN) + ) + ); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + mma(params.gemm_k_iterations, accumulators, iterator_A, + iterator_B, iterator_scale_bias, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. 
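+      // The semaphore wraps a single int of global-memory workspace per output tile
+      // (indexed by block_idx above). Issuing the non-blocking fetch() this early lets the
+      // later wait() often complete without stalling the start of the epilogue.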
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + + } + // Each split-k-slice writes to a unique tensor location + else if (params.split_k_mode == SplitKMode::kParallel) { + iterator_D.add_pointer_offset(threadblock_tile_idx.k() * + cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); + } + + // Run efficient epilogue + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h new file mode 100644 index 00000000..3409f173 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h @@ -0,0 +1,492 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined Implicit GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize ///! 
Convolutional operator on 2D or 3D problem +> +struct ImplicitGemmConvolutionStridedDgrad { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueOutputOp::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C (KxTRSC) + static int const kWgradCStrideIdx = + platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? 
kWgradCStrideIdx : 0); + + // Strided dgrad uses a specialized threadblock swizzle for functionality and performance + static_assert((platform::is_same::value) || + (platform::is_same>::value) || + (platform::is_same>::value) || + (platform::is_same>::value), + "Needs ThreadblockSwizzle type specialized for strided dgrad"); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size{}; + TensorRefA ref_A{}; + TensorRefB ref_B{}; + TensorRefC ref_C{}; + TensorRefC ref_D{}; + typename EpilogueOutputOp::Params output_op{}; + SplitKMode split_k_mode{}; + + // + // Methods + // + + /// Default ctor + Arguments() = default; + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), + split_k_mode(split_k_mode) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size{}; + cutlass::gemm::GemmCoord grid_tiled_shape{}; + int swizzle_log_tile{0}; + FastDivmod stride_h_divmod{}; + FastDivmod stride_w_divmod{}; + int gemm_k_iterations{0}; + typename Mma::IteratorA::Params iterator_A{}; + typename Mma::IteratorA::Element const *ptr_A = nullptr; + typename Mma::IteratorB::Params iterator_B{}; + typename Mma::IteratorB::Element const *ptr_B = nullptr; + typename Epilogue::OutputTileIterator::Params iterator_C{}; + typename Epilogue::OutputTileIterator::Element *ptr_C = nullptr; + typename Epilogue::OutputTileIterator::Params iterator_D{}; + typename Epilogue::OutputTileIterator::Element *ptr_D = nullptr; + typename EpilogueOutputOp::Params output_op {}; + int *semaphore = nullptr; + SplitKMode split_k_mode {}; + + // + // Methods + // + Params() = default; + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size(args.problem_size), + stride_h_divmod(args.problem_size.stride_h), + stride_w_divmod(args.problem_size.stride_w), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(args.problem_size, args.ref_B.layout()), + ptr_B(args.ref_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size, ThreadblockShape::kM), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size, ThreadblockShape::kM), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode) + { + gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = 
threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ImplicitGemmConvolutionStridedDgrad() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Compute starting filter position for strided dgrad + int tile_m_per_filter = strided_dgrad_tile_m_per_filter(params.problem_size, + ThreadblockShape::kM); + int filter_tile_m = (threadblock_tile_idx.m() / tile_m_per_filter); + + + // The subsequent fast_divmod() operations are equivalent to the following logical computation: + // + // int start_r = filter_tile_m / (params.problem_size.stride_w); + // int start_s = filter_tile_m % (params.problem_size.stride_w); + + int start_r, start_s; + params.stride_w_divmod(start_r, start_s, filter_tile_m); + + int filter_r = start_r; + int filter_s = start_s; + + if (params.problem_size.mode == Mode::kConvolution) { + filter_r = (params.problem_size.R - 1 - filter_r); + filter_s = (params.problem_size.S - 1 - filter_s); + } + + // Starting h, w positions for filter position in gemm_k=0 + int start_h, start_w; + strided_dgrad_starting_coords( + params.problem_size, + params.stride_h_divmod, params.stride_w_divmod, + filter_r, filter_s, + start_h, start_w); + + if (start_h >= params.problem_size.H || start_w >= params.problem_size.W) { + return; + } + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
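+    // canonical_warp_idx_sync() is the CUTLASS helper for the same lane-0 broadcast of the
+    // warp id (effectively __shfl_sync(0xffffffff, threadIdx.x / 32, 0)) used by the other
+    // kernels in this directory, keeping warp_idx warp-uniform.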
+ int warp_idx = canonical_warp_idx_sync(); + int lane_idx = threadIdx.x % 32; + + // Check if CTA contributes valid MMA (Dy * w) and accumulator will be non-zero after MMA + if (start_r < params.problem_size.R && start_s < params.problem_size.S) { + // Scale gemm_k_iterations for strided dgrad + int gemm_k_iterations = (params.gemm_k_iterations / (params.problem_size.R * params.problem_size.S) + ) * params.problem_size.num_gemm_k_filter_positions(start_r, start_s); + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + params.stride_h_divmod, params.stride_w_divmod, + start_r, start_s, + MatrixCoord( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.k() * Mma::Shape::kK + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_B, + thread_idx, + start_r, start_s, + MatrixCoord( + threadblock_tile_idx.k() * Mma::Shape::kK, + threadblock_tile_idx.n() * Mma::Shape::kN + ) + ); + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + params.stride_h_divmod, params.stride_w_divmod, + start_r, start_s, + threadblock_offset + ); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + if (output_op.is_source_needed()) + { + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + params.stride_h_divmod, params.stride_w_divmod, + start_r, start_s, + threadblock_offset); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. 
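+        // Serial split-K chaining: partition 0 reads the user-supplied source tensor C,
+        // while every later partition reads the partial sums the previous partition has
+        // already written to D. Assigning iterator_D to iterator_C here is what links the
+        // partitions into a single running reduction.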
+ if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + } + + // Run epilogue with addend source iterator + epilogue(output_op, iterator_D, accumulators, iterator_C); + } + else + { + // Run epilogue without addend source iterator + epilogue(output_op, iterator_D, accumulators); + } + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h new file mode 100644 index 00000000..529808fe --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h @@ -0,0 +1,494 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Convolution kernel with an epilogue that computes the absolute maximum value of the output + and a pre-activation-function auxiliary output. The auxiliary output is also (optionally) + stored to global memory. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem +> +struct ImplicitGemmConvolutionWithAbsMax { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueOutputOp::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefC = cutlass::TensorRef; + using TensorRefAux = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C 
(KxTRSC) + static int const kWgradCStrideIdx = + platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefC ref_C; + TensorRefC ref_D; + TensorRefC ref_Aux; + + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + void * ptr_Vector; + + typename LayoutC::Stride::Index ldr; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + TensorRefAux const & ref_Aux, + typename EpilogueOutputOp::Params const & output_op, + SplitKMode const & split_k_mode = SplitKMode::kSerial, + void * ptr_Vector = nullptr, + typename LayoutC::Stride::Index ldr = 0 + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + ref_Aux(ref_Aux), + output_op(output_op), + split_k_mode(split_k_mode), + ptr_Vector(ptr_Vector), + ldr(ldr) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size; + int swizzle_log_tile; + + int gemm_k_iterations; + typename Mma::IteratorA::Params iterator_A; + typename Mma::IteratorA::Element const *ptr_A; + typename Mma::IteratorB::Params iterator_B; + typename Mma::IteratorB::Element const *ptr_B; + typename Epilogue::OutputTileIterator::Params iterator_C; + typename Epilogue::OutputTileIterator::Element *ptr_C; + typename Epilogue::OutputTileIterator::Params iterator_D; + typename Epilogue::OutputTileIterator::Element *ptr_D; + typename Epilogue::AuxOutputTileIterator::Params iterator_Aux; + typename Epilogue::AuxOutputTileIterator::Element *ptr_Aux; + typename EpilogueOutputOp::Params output_op; + int *semaphore; + SplitKMode split_k_mode; + + void * ptr_Vector; + typename LayoutC::Stride::Index ldr; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + gemm_k_iterations(0), + ptr_Vector(nullptr), + ldr(0) + { } + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size(args.problem_size), + implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(args.problem_size, args.ref_B.layout()), + ptr_B(args.ref_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), + ptr_D(args.ref_D.data()), + iterator_Aux(ConvOutputIteratorParameter::layout(args.ref_Aux)), + ptr_Aux(args.ref_Aux.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode), 
+ ptr_Vector(args.ptr_Vector), + ldr(args.ldr) + + { + gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + implicit_gemm_problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ImplicitGemmConvolutionWithAbsMax() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.k() * Mma::Shape::kK + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_B, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * Mma::Shape::kK, + threadblock_tile_idx.n() * Mma::Shape::kN + ) + ); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. 
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to auxiliary tensor. + typename Epilogue::AuxOutputTileIterator iterator_Aux( + params.iterator_Aux, + params.ptr_Aux, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + threadblock_tile_idx.m() * params.ldr; + } + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + + } + // Each split-k-slice writes to a unique tensor location + else if (params.split_k_mode == SplitKMode::kParallel) { + iterator_D.add_pointer_offset(threadblock_tile_idx.k() * + cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + // Only the final block uses Vector + ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) + ? nullptr + : ptr_Vector, + iterator_D, + accumulators, + iterator_C, + iterator_Aux, + ConvOutputIteratorParameter::extent(params.problem_size), + threadblock_offset); + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. 
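+        // (Serial split-K protocol, as implemented around this point: slice k
+        // waits until the semaphore equals k, accumulates its partial result
+        // into the 'D' tensor, then releases k + 1 so the next slice can
+        // proceed; the final slice releases 0 so the semaphore is ready for
+        // reuse by a subsequent launch.)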
+ lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h new file mode 100644 index 00000000..d79e6ef2 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h @@ -0,0 +1,499 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined Implicit GEMM kernel. 
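+
+    This variant threads an extra broadcast vector (ptr_Vector / ldr) and an
+    auxiliary tensor operand (ptr_Tensor / ldt) through Arguments and Params so
+    that the epilogue can fuse an additional elementwise source or output
+    alongside the usual C and D tensors.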
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem +> +struct ImplicitGemmConvolutionWithFusedEpilogue { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueOutputOp::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C (KxTRSC) + static int const 
kWgradCStrideIdx = + platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefC ref_C; + TensorRefC ref_D; + + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + void * ptr_Vector; + void * ptr_Tensor; + + typename LayoutC::Stride::Index ldr; + typename LayoutC::Stride::Index ldt; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + SplitKMode const & split_k_mode = SplitKMode::kSerial, + void * ptr_Vector = nullptr, + void * ptr_Tensor = nullptr, + typename LayoutC::Stride::Index ldr = 0, + typename LayoutC::Stride::Index ldt = 0 + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), + split_k_mode(split_k_mode), + ptr_Vector(ptr_Vector), + ptr_Tensor(ptr_Tensor), + ldr(ldr), + ldt(ldt) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size; + int swizzle_log_tile; + + int gemm_k_iterations; + typename Mma::IteratorA::Params iterator_A; + typename Mma::IteratorA::Element const *ptr_A; + typename Mma::IteratorB::Params iterator_B; + typename Mma::IteratorB::Element const *ptr_B; + typename Epilogue::OutputTileIterator::Params iterator_C; + typename Epilogue::OutputTileIterator::Element *ptr_C; + typename Epilogue::OutputTileIterator::Params iterator_D; + typename Epilogue::OutputTileIterator::Element *ptr_D; + typename EpilogueOutputOp::Params output_op; + int *semaphore; + SplitKMode split_k_mode; + + typename Epilogue::TensorTileIterator::Params params_Tensor; + void * ptr_Vector; + typename LayoutC::Stride::Index ldr; + void * ptr_Tensor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + gemm_k_iterations(0), + ptr_Vector(nullptr), + ldr(0), + ptr_Tensor(nullptr) + { } + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size(args.problem_size), + implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(args.problem_size, args.ref_B.layout()), + ptr_B(args.ref_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode), + params_Tensor(args.ldt), + 
ptr_Vector(args.ptr_Vector), + ldr(args.ldr), + ptr_Tensor(args.ptr_Tensor) + + { + gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + implicit_gemm_problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ImplicitGemmConvolutionWithFusedEpilogue() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.k() * Mma::Shape::kK + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_B, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * Mma::Shape::kK, + threadblock_tile_idx.n() * Mma::Shape::kN + ) + ); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. 
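+      // The fetch below only issues the load of the semaphore value; the
+      // blocking wait happens later, after the output tile iterators have been
+      // constructed, so the synchronization latency can be hidden behind that
+      // address computation.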
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + typename Epilogue::ElementTensor *ptr_Tensor = + static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + // Only the final block outputs Tensor + ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) + ? nullptr + : ptr_Tensor, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + threadblock_tile_idx.m() * params.ldr; + } + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + + } + // Each split-k-slice writes to a unique tensor location + else if (params.split_k_mode == SplitKMode::kParallel) { + iterator_D.add_pointer_offset(threadblock_tile_idx.k() * + cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + // Only the final block uses Vector + ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) + ? nullptr + : ptr_Vector, + iterator_D, + accumulators, + iterator_C, + tensor_iterator, + ConvOutputIteratorParameter::extent(params.problem_size), + threadblock_offset); + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp new file mode 100644 index 00000000..c6996f15 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp @@ -0,0 +1,391 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/cluster_sm90.hpp" + +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/pipeline/sm90_pipeline.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag +> +class ConvUniversal< + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag, + cute::enable_if_t>> +{ +public: + // + // Type Aliases + // + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_void_v, + "TMA warp-specialized kernel does not support specializing the tile scheduler."); + using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + + // Kernel level shared memory storage + struct SharedStorage { + union TensorStorage { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 1; + static constexpr uint32_t MaxThreadsPerBlock = 
CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Host facing host arguments + struct Arguments { + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + MainloopParams mainloop; + EpilogueParams epilogue; + }; + + // + // Methods + // + + // Map user facing arguments to device facing params + static Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto mainloop_params = CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace); + auto problem_shape_MNKL = append<4>(mainloop_params.problem_shape, Int<1>{}); + + return { + mainloop_params, + CollectiveEpilogue::to_underlying_arguments(problem_shape_MNKL, args.epilogue, workspace) + }; + } + + // Given arguemnts, returns true if the kernel can successfully compute upon them. False otherwise. + static bool + can_implement(Arguments const& args) { + bool implementable = true; + implementable &= CollectiveMainloop::can_implement(args.mainloop.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.mainloop.problem_shape.get_transformed_problem_shape_MNK(), args.epilogue); + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + return 0; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // The CONV mainloop params problem shape will be the cute::Shape<> rank-3 MNK tuple we want for grid planning + // Although conv problems do not have an L mode, we add it here to comply with the scheduler API + auto linear_problem_shape_MNKL = make_shape( + size<0>(params.mainloop.problem_shape), // M mode is linearized. + shape<1>(params.mainloop.problem_shape), + shape<2>(params.mainloop.problem_shape), + Int<1>{}); + + return cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl( + linear_problem_shape_MNKL, TileShape{}, ClusterShape{}); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. 
Aborting.\n"); + return; + } + #endif + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int warp_idx = canonical_warp_idx_sync(); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + int lane_predicate = cute::elect_one_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = 1; // 1 thread issues TMA load + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + // Separate out problem shape for convenience + auto M = get<0>(params.mainloop.problem_shape); + auto N = get<1>(params.mainloop.problem_shape); + auto K = get<2>(params.mainloop.problem_shape); + // output strides are coalesced so we linearize the 
output shape to match the shape/stride profiles + auto linear_problem_shape_MNKL = make_shape(size(M), N, K, Int<1>{}); + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mk = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M, size(K))); + Tensor mB_nk = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N, K)); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto cta_tile_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + TiledMma tiled_mma; + + // Make tiled views, defer the slice + Tensor gA_mk = local_tile(mA_mk, cta_tile_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k) + Tensor gB_nk = local_tile(mB_nk, cta_tile_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mk)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nk)); + // The output shape M is linearized so the output coord M here should also be linearized. + auto output_tile_coord = make_coord(int(blockIdx.x), n_coord, _, Int<0>{}); + + // Slice with m_coord and n_coord + Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k) + Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k) + + // Get pipeline iterators and increments from tensor shapes + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + auto k_tile_count = size<2>(gA); + + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(cta_tile_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(cta_tile_shape); + + // Make sure pipeline init is visible to all producers and consumer CTAs in cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } + else { + __syncthreads(); + } + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + + if (warp_group_role == WarpGroupRole::Producer) { + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, params.mainloop.tma_load_a, + gB, params.mainloop.tma_load_b, + k_tile_iter, k_tile_count, + thread_idx, + shared_storage.tensors.mainloop + ); + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + linear_problem_shape_MNKL, + cta_tile_shape, + output_tile_coord, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + // Update starting load pipeline state for the pipeline drain + epi_load_pipe_producer_state.advance(c_tile_count); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } + else if (warp_group_role == WarpGroupRole::Consumer) { + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(cta_tile_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + thread_idx, + shared_storage.tensors.mainloop, + 
params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + + // Epilogue and write to gD + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + linear_problem_shape_MNKL, + cta_tile_shape, + output_tile_coord, + accumulators, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::kernel diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/thread/depthwise_mma.h b/server/punica_kernels/include/cutlass/cutlass/conv/thread/depthwise_mma.h new file mode 100644 index 00000000..37ece792 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/thread/depthwise_mma.h @@ -0,0 +1,325 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Templates exposing architecture support for depthwise convolution +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/thread/mma.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// MMA operation +template < + /// Size of the matrix product (concept: GemmShape) + typename Shape_, + /// Number of threads participating + int kThreads_, + /// Data type of A elements + typename ElementA, + /// Data type of B elements + typename ElementB, + /// Element type of C matrix + typename ElementC, + /// Inner product operator + typename Operator +> +struct ElementwiseInnerProduct; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// General implementation +template < + /// Size of the matrix product (concept: GemmShape) + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Data type of B elements + typename ElementB_, + /// Element type of C matrix + typename ElementC_> +struct ElementwiseInnerProduct { + using Shape = Shape_; + using Operator = arch::OpMultiplyAdd; + using ElementC = ElementC_; + + CUTLASS_HOST_DEVICE + void operator()(Array &d, + Array const &a, + Array const &b, + Array const &c) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Shape::kN; ++i) { + d[i] = a[i] * b[i] + c[i]; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Specialization of half_t +template <> +struct ElementwiseInnerProduct< + gemm::GemmShape<2, 2, 1>, + 1, + half_t, + half_t, + half_t, + arch::OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 2, 1>; + using Operator = arch::OpMultiplyAdd; + using ElementC = half_t; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) + + __half2 const & A = reinterpret_cast<__half2 const &>(a); + __half2 const & B = reinterpret_cast<__half2 const &>(b); + __half2 const & C = reinterpret_cast<__half2 const &>(c); + + __half2 tmp_D = __hfma2(A, B, C); + + d = reinterpret_cast const &>(tmp_D); + +#else + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + d[i] = a[i] * b[i] + c[i]; + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Data type of B elements + typename ElementB, + /// Element type of C matrix + typename ElementC, + /// Concept: arch::OpMultiplyAdd or arch::Mma<> + typename Operator = arch::OpMultiplyAdd, + /// Used for partial specialization + typename Enable = bool +> +struct DepthwiseDirectConvElementwiseInnerProduct; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Gemplate that handles all packed matrix layouts +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Data type of B elements + typename 
ElementB_, + /// Element type of C matrix + typename ElementC_, + /// Operator used to compute GEMM + typename Operator_ +> +struct DepthwiseDirectConvElementwiseInnerProductGeneric { + + /// Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + /// Data type of operand A + using ElementA = ElementA_; + + /// Data type of operand B + using ElementB = ElementB_; + + /// Element type of operand C + using ElementC = ElementC_; + + /// Underlying mathematical operator + using Operator = Operator_; + + /// A operand storage + using FragmentA = Array; + + /// B operand storage + using FragmentB = Array; + + /// C operand storage + using FragmentC = Array; + + /// Instruction + using MmaOp = cutlass::conv::thread::ElementwiseInnerProduct< + gemm::GemmShape, + 1, + ElementA, + ElementB, + ElementC, + Operator>; + + + // + // Methods + // + + /// Computes a matrix product D = A * B + C + CUTLASS_HOST_DEVICE + void operator()( + FragmentC & D, + FragmentA const & A, + FragmentB const & B, + FragmentC const & C) { + Array *ptr_D = reinterpret_cast *>(&D); + Array const *ptr_A = + reinterpret_cast const *>(&A); + Array const *ptr_B = + reinterpret_cast const *>(&B); + + MmaOp mma_op; + + // Copy accumulators + D = C; + + // Compute matrix product + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Shape::kN / MmaOp::Shape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Shape::kM; ++m) { + + Array tmpD = ptr_D[m * Shape::kN / MmaOp::Shape::kN + n]; + Array tmpA = ptr_A[m * Shape::kN / MmaOp::Shape::kN + n]; + Array tmpB = ptr_B[n]; + + mma_op(tmpD, tmpA, tmpB, tmpD); + + ptr_D[m * Shape::kN / MmaOp::Shape::kN + n] = tmpD; + + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Data type of B elements + typename ElementB_, + /// Element type of C matrix + typename ElementC_ +> +struct DepthwiseDirectConvElementwiseInnerProduct< + Shape_, + ElementA_, + ElementB_, + ElementC_, + arch::OpMultiplyAdd + > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + /// Data type of operand A + using ElementA = ElementA_; + + /// Data type of operand B + using ElementB = ElementB_; + + /// Element type of operand C + using ElementC = ElementC_; + + /// Underlying mathematical operator + using Operator = arch::OpMultiplyAdd; + + /// A operand storage + using FragmentA = + Array; // output_tile_size per thread * groups_per_thread + + /// B operand storage + using FragmentB = Array; // 1 * groups_per_thread + + /// C operand storage + using FragmentC = + Array; // output_tile_size per thread * groups_per_thread + + static bool const use_optimized = 0; + + using ArchMmaOperator = DepthwiseDirectConvElementwiseInnerProductGeneric; + + // + // Methods + // + + /// Computes a matrix product D = A * B + C + CUTLASS_HOST_DEVICE + void operator()( + FragmentC & D, + FragmentA const & A, + FragmentB const & B, + FragmentC const & C) { + + ArchMmaOperator mma; + + mma(D, A, B, C); + + } +}; + +} // namespace thread +} // namespace conv +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h new 
file mode 100644 index 00000000..b396c8c8 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h @@ -0,0 +1,485 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
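+
+    Two Dgrad specializations follow: a strided variant that starts at
+    (start_r, start_s) and advances the filter position by (stride_h, stride_w),
+    visiting only the filter taps that contribute to the current output, and a
+    unity-stride variant that walks every (r, s) position in turn.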
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dDgradFilterTileAccessIteratorAnalytic; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradFilterTileAccessIteratorAnalytic strided dgrad needs special handling to skip MMAs +// on non-contributing w positions +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradFilterTileAccessIteratorAnalytic < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kStrided, + AccessType_ +> { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or larger."); + + // + // Parameters structure + // + + using Params = Conv2dAnalyticParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension + int filter_r_; + int filter_s_; + int start_r_; + int start_s_; + int offset_k_[ThreadMap::Iterations::kStrided]; + int offset_c_[ThreadMap::Iterations::kContiguous]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_r_(start_r), + filter_s_(start_s), + start_r_(start_r), + start_s_(start_s) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + CUTLASS_PRAGMA_UNROLL + for 
(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + } + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = + threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Moves filter_s + filter_s_ += problem_size_.stride_w; + if (filter_s_ < problem_size_.S) { + return; + } + // Restore filter_s + filter_s_ = start_s_; + + // Move filter_r + filter_r_ += problem_size_.stride_h; + if (filter_r_ < problem_size_.R) { + return; + } + // Restore filter_r + filter_r_ = start_r_; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the filter tensor w that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int k = offset_k_[iteration_strided_]; + int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; + + return TensorCoord(k, filter_r_, filter_s_, c); + } + + /// Returns true if the current coordinate is within the filter tensor w + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
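+  /// For this analytic iterator the only constraint checked is vector alignment:
+  /// the channel count C must be divisible by AccessType::kElements, otherwise
+  /// the contiguous-dimension accesses cannot be expressed as whole vectors.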
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradFilterTileAccessIteratorAnalytic unity strided dgrad is more performant for dgrad +// on problem sizes with stride = {1x1} +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradFilterTileAccessIteratorAnalytic < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kUnity, + AccessType_ +>{ +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or larger."); + + // + // Parameters structure + // + + using Params = Conv2dAnalyticParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension + int filter_r_; + int filter_s_; + int offset_k_[ThreadMap::Iterations::kStrided]; + int offset_c_[ThreadMap::Iterations::kContiguous]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + } + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = + threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + 
CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the filter tensor w that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int k = offset_k_[iteration_strided_]; + int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; + + return TensorCoord(k, filter_r_, filter_s_, c); + } + + /// Returns true if the current coordinate is within the filter tensor w + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h new file mode 100644 index 00000000..a84e0899 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h @@ -0,0 +1,619 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dDgradFilterTileAccessIteratorOptimized; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad +// on problem sizes with stride = {1x1} +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradFilterTileAccessIteratorOptimized < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kStrided, + AccessType_ + > { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static 
IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Parameters structure + // + + struct Params : Conv2dStridedDgradFilterIteratorOptimizedParams { + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Conv2dStridedDgradFilterIteratorOptimizedParams const &base): + Conv2dStridedDgradFilterIteratorOptimizedParams(base) { } + + CUTLASS_HOST_DEVICE + Params( + Conv2dProblemSize const &problem_size, + Layout const &layout + ): + Conv2dStridedDgradFilterIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ) { } + + }; + +private: + + Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + uint32_t predicates_[kAccessesPerVector]; + int filter_k_; + int filter_r_; + int filter_s_; + + int start_r_; + int start_s_; + + int64_t reset_bytes_s_; + int64_t reset_bytes_r_; + + // + // Assertions + // + + // We map predicates into bits packed in this uint32_t container + static_assert(ThreadMap::Iterations::kStrided * + ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, + "Currently, the number of loads per iteration is limited by the size of the predicates container."); + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorOptimized( + Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_{0}, + filter_r_(start_r), + filter_s_(start_s), + start_r_(start_r), + start_s_(start_s) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.row() + thread_coord.strided(); + Index column = threadblock_offset.column() + thread_coord.contiguous(); + + reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0]; + reset_bytes_r_ = reset_bytes_s_ + + (problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1]; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; + int filter_c = column + c * ThreadMap::Delta::kContiguous; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 
1u : 0); + + int pred_idx = c + s * ThreadMap::Iterations::kContiguous; + + predicates_[v] |= (pred << pred_idx); + } + } + } + + TensorCoord coord{filter_k_, filter_r_, filter_s_, column}; + + pointer_ += params_.layout(coord) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void advance() { + + int next_idx = 0; + LongIndex reset_bytes = params_.reset_bytes; + + // Move filter_s by stride_w + filter_s_ += problem_size_.stride_w; + if (filter_s_ >= problem_size_.S) { + + // Restore filter_s + filter_s_ = start_s_; + + // Move filter_r by stride_h + filter_r_ += problem_size_.stride_h; +#if 0 + bool check = (filter_r_ < problem_size_.R); + + filter_r_ = check ? filter_r_ : start_r_; + next_idx = check ? 1 : 2; + reset_bytes += (check ? reset_bytes_s_ : reset_bytes_r_); +#else + asm volatile( + "{\n\t" + " .reg .pred %%p;\n\t" + " .reg .s64 t1;\n\t" + " setp.lt.s32 %%p, %3, %4;\n\t" + " selp.s32 %0, %3, %5, %%p;\n\t" + " selp.s32 %1, 1, 2, %%p;\n\t" + " selp.s64 t1, %6, %7, %%p;\n\t" + " add.s64 %2, %8, t1;\n\t" + "}\n" + : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) + : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), + "l"(reset_bytes_s_), "l"(reset_bytes_r_), "l"(reset_bytes)); +#endif + } + + // offset pointers by offset_bytes + pointer_ += (params_.inc_next[next_idx] - reset_bytes); + + if (next_idx == 2) { + filter_k_ += params_.filter_k_delta; + } + + // Clear predicates if needed + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { + uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + predicates_[v] = (predicates_[v] & (~kClearMask)); + } + } + } + } + + /// Returns true if the current coordinate is within the filter tensor W + CUTLASS_HOST_DEVICE + bool valid() { + LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; + return (predicates_[iteration_vector_] & (1u << pred_idx)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + return reinterpret_cast(pointer_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + + // Move to the next K coordinate within the tile + pointer_ += 
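// [Illustrative sketch - not part of the diff.] The optimized filter iterator above guards
// every load with a bit packed into uint32_t predicates_[kAccessesPerVector]: bit
// (c + s * Iterations::kContiguous) of predicates_[v] records whether access (s, c, v) is in
// bounds, and valid() only tests that bit. The standalone snippet below rebuilds the packing
// with assumed, hypothetical tile sizes and problem extents.
#include <cassert>
#include <cstdint>

int main() {
  constexpr int kIterationsStrided    = 4;   // ThreadMap::Iterations::kStrided (assumed)
  constexpr int kIterationsContiguous = 2;   // ThreadMap::Iterations::kContiguous (assumed)
  constexpr int kAccessesPerVector    = 1;   // assumed
  constexpr int kAccessElements       = 8;   // AccessType::kElements (assumed)
  constexpr int kDeltaStrided         = 8;   // ThreadMap::Delta::kStrided (assumed)
  int K = 24, C = 6;                         // problem_size.K, problem_size.C (assumed)

  std::uint32_t predicates[kAccessesPerVector] = {0};

  // Same packing as the constructor above: one bit per (s, c) access, per vector slot v.
  for (int s = 0; s < kIterationsStrided; ++s) {
    for (int c = 0; c < kIterationsContiguous; ++c) {
      int filter_k = s * kDeltaStrided;
      int filter_c = c * kAccessElements;
      for (int v = 0; v < kAccessesPerVector; ++v) {
        std::uint32_t pred =
            (filter_k < K && (filter_c + v * kAccessElements) < C) ? 1u : 0u;
        int pred_idx = c + s * kIterationsContiguous;
        predicates[v] |= (pred << pred_idx);
      }
    }
  }

  // valid() reduces to a single bit test per access.
  auto valid = [&](int s, int c, int v) {
    return ((predicates[v] >> (c + s * kIterationsContiguous)) & 1u) != 0u;
  };
  assert(valid(0, 0, 0));   // k = 0,  c = 0 -> in bounds
  assert(!valid(0, 1, 0));  // k = 0,  c = 8 -> c out of bounds (C = 6)
  assert(!valid(3, 0, 0));  // k = 24, c = 0 -> k out of bounds (K = 24)
  return 0;
}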
params_.inc_next_strided; + + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad +// on problem sizes with stride = {1x1} +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradFilterTileAccessIteratorOptimized < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kUnity, + AccessType_ + > { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Parameters structure + // + + struct Params : Conv2dDgradFilterIteratorOptimizedParams { + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Conv2dDgradFilterIteratorOptimizedParams const &base): + Conv2dDgradFilterIteratorOptimizedParams(base) { } + + CUTLASS_HOST_DEVICE + Params( + Conv2dProblemSize const &problem_size, + Layout const &layout + ): + Conv2dDgradFilterIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ) { } + + }; + +private: + + Conv2dDgradFilterIteratorOptimizedParams const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + uint32_t predicates_[kAccessesPerVector]; + int filter_rs_; + int filter_k_; + + // + // Assertions + // + + // We map predicates into bits packed in this uint32_t container + static_assert(ThreadMap::Iterations::kStrided * + ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, + "Currently, the number of loads per iteration is limited by the size of the predicates container."); + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorOptimized( + Conv2dDgradFilterIteratorOptimizedParams const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), 
+ predicates_{0}, + filter_rs_(0), + filter_k_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.row() + thread_coord.strided(); + Index column = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; + int filter_c = column + c * ThreadMap::Delta::kContiguous; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0); + + int pred_idx = c + s * ThreadMap::Iterations::kContiguous; + + predicates_[v] |= (pred << pred_idx); + } + } + } + + pointer_ += ( + filter_k_ * params.layout.stride()[2] + column + ) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + LongIndex next = params_.inc_next_rs; + + // moves to the next tile + ++filter_rs_; + if (filter_rs_ == params_.RS) { + + filter_rs_ = 0; + next = params_.inc_next_k; + filter_k_ += params_.filter_k_delta; + } + + // Clear predicates if needed + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { + uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + predicates_[v] = (predicates_[v] & (~kClearMask)); + } + } + } + + pointer_ += next; + } + + /// Returns true if the current coordinate is within the filter tensor W + CUTLASS_HOST_DEVICE + bool valid() { + LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; + return (predicates_[iteration_vector_] & (1u << pred_idx)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + return reinterpret_cast(pointer_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + + // Move to the next K coordinate within the tile + pointer_ += params_.inc_next_strided; + + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the 
Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h new file mode 100644 index 00000000..6ed7a556 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -0,0 +1,604 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/functional.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dDgradOutputGradientTileAccessIteratorAnalytic; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradOutputGradientTileAccessIteratorAnalytic strided dgrad needs special handling using +// unscaled coordinations +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kStrided, + AccessType_ +> { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or greater."); + + // + // Simpligying assertions + // + + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_k_; + int filter_r_; + int filter_s_; + int start_r_; + int start_s_; + + int offset_n_[ThreadMap::Iterations::kStrided]; + int offset_p_[ThreadMap::Iterations::kStrided]; + int offset_q_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole 
CTA tiles + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_k_(0), + filter_r_(start_r), + filter_s_(start_s), + start_r_(start_r), + start_s_(start_s) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + int filter_r = filter_r_; + int filter_s = filter_s_; + + if (problem_size_.mode == Mode::kConvolution) { + filter_r = (problem_size_.R - 1 - filter_r); + filter_s = (problem_size_.S - 1 - filter_s); + } + + // Starting h, w positions for filter position in gemm_k=0 + int start_h, start_w; + strided_dgrad_starting_coords( + problem_size_, + stride_h_divmod, stride_w_divmod, + filter_r, filter_s, + start_h, start_w); + + // Effective P and Q for filter position required for remapping NHW rows + int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h; + int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w; + + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter; + + // (STEP 1) [reorder NHW rows to start with same filter positions] + offset_n_[s] = offset_npq / (P * Q); + int residual = offset_npq % (P * Q); + + int p = (residual / Q); + int q = (residual % Q); + + int mapped_h = (start_h + p * problem_size_.stride_h); + int mapped_w = (start_w + q * problem_size_.stride_w); + + // Access (p, q) coordinates for Dy tensor and a filter position in gemm_k=0 + // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are divisible + // by stride_h and stride_w + offset_p_[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h; + offset_q_[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w; + } + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + // Move filter_s by stride_w + filter_s_ += problem_size_.stride_w; + if (filter_s_ < problem_size_.S) { + return; + } + + // Restore filter_s + filter_s_ = start_s_; + + // Move filter_r by stride_h + filter_r_ += problem_size_.stride_h; + if (filter_r_ < problem_size_.R) { + return; + } + + // Restore filter_r + filter_r_ = start_r_; + + // Move filter_k + filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the output tensor Dy that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + int n = offset_n_[iteration_strided_]; + int p = offset_p_[iteration_strided_]; + int q = offset_q_[iteration_strided_]; + + int conv_sign = (problem_size_.mode == Mode::kConvolution ? 
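// [Illustrative sketch - not part of the diff.] For strided dgrad the analytic iterator
// above works on unscaled coordinates: for the current filter position it first computes
// how many activation rows/columns remain reachable from the starting coordinate (a
// ceiling division), then splits a linear NPQ row index into (n, p, q) and maps back to
// (h, w). The standalone snippet below reproduces that arithmetic with assumed sizes;
// start_h/start_w would normally come from strided_dgrad_starting_coords().
#include <cassert>

int main() {
  int H = 7, W = 9, stride_h = 2, stride_w = 2;  // assumed problem extents and strides
  int start_h = 1, start_w = 0;                  // assumed starting coords for this filter position

  // Effective extents: ceil((H - start_h) / stride_h) rows are covered from start_h.
  int P = (H - start_h + stride_h - 1) / stride_h;  // ceil(6 / 2) = 3
  int Q = (W - start_w + stride_w - 1) / stride_w;  // ceil(9 / 2) = 5
  assert(P == 3 && Q == 5);

  // Split a linear row index (offset_npq) into (n, p, q), then map back to (h, w).
  int offset_npq = 23;
  int n = offset_npq / (P * Q);           // 23 / 15 = 1
  int residual = offset_npq % (P * Q);    // 8
  int p = residual / Q;                   // 1
  int q = residual % Q;                   // 3
  int mapped_h = start_h + p * stride_h;  // 3
  int mapped_w = start_w + q * stride_w;  // 6
  assert(n == 1 && mapped_h == 3 && mapped_w == 6);
  return 0;
}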
1 : -1); + + p += (conv_sign * (filter_r_ / problem_size_.stride_h)); + q += (conv_sign * (filter_s_ / problem_size_.stride_w)); + + int k = filter_k_ + iteration_vector_ * AccessType::kElements; + + return TensorCoord( + n, + p, + q, + k); + } + + + /// Returns true if the current coordinate is within the output tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return + coord.n() < problem_size_.N && + coord.h() >= 0 && coord.h() < problem_size_.P && + coord.w() >= 0 && coord.w() < problem_size_.Q && + coord.c() < problem_size_.K; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradOutputGradientTileAccessIteratorAnalytic for unity strides can be optimized by +// eliminating modulo arithmetic to compute unscaled coordinates +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kUnity, + AccessType_ +> { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or greater."); + + // + // Simpligying assertions + // + + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + struct Params { + + Layout layout; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + 
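// [Illustrative sketch - not part of the diff.] In the at() shown above, the base (p, q)
// was precomputed for the filter position at gemm_k = 0; as filter_r_ / filter_s_ advance
// by stride_h / stride_w, the output-gradient coordinate shifts by one row/column per step.
// The direction depends on the mode: for cross-correlation larger filter taps read earlier
// Dy rows (sign -1), while kConvolution mirrors the filter and reverses that (sign +1).
// A small standalone sketch of the update, with hypothetical helper and values:
#include <cassert>

enum class Mode { kConvolution, kCrossCorrelation };

// Hypothetical helper mirroring the conv_sign adjustment in at().
int shifted_p(int base_p, int filter_r, int stride_h, Mode mode) {
  int conv_sign = (mode == Mode::kConvolution ? 1 : -1);
  return base_p + conv_sign * (filter_r / stride_h);
}

int main() {
  // Assumed values: base_p = 4, stride_h = 2, filter_r advanced to 4 (two steps).
  assert(shifted_p(4, 4, 2, Mode::kCrossCorrelation) == 2);
  assert(shifted_p(4, 4, 2, Mode::kConvolution) == 6);
  return 0;
}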
CUTLASS_HOST_DEVICE + Params( + Conv2dProblemSize const &problem_size, + Layout const &layout + ): layout(layout) { + + } + }; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_k_; + int filter_r_; + int filter_s_; + + int offset_n_[ThreadMap::Iterations::kStrided]; + int offset_w_[ThreadMap::Iterations::kStrided]; + int offset_h_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_k_(0), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + offset_n_[s] = offset_nhw / (problem_size_.H * problem_size_.W); + int residual = offset_nhw % (problem_size_.H * problem_size_.W); + + offset_h_[s] = residual / problem_size_.W; + offset_w_[s] = residual % problem_size_.W; + } + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // move to the next tile + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + + filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the output tensor Dy that is currently pointed to + /// by the iterator. 
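// [Illustrative sketch - not part of the diff.] The unity-stride constructor above turns
// each GEMM row index into an activation coordinate by peeling off N, then H, then W with
// plain integer div/mod; the optimized iterator later in this patch performs the same
// decomposition with precomputed fast_divmod objects. Standalone sketch with assumed sizes:
#include <cassert>

int main() {
  int H = 5, W = 7;     // assumed activation extent
  int offset_nhw = 40;  // assumed linear row index for one access

  int n = offset_nhw / (H * W);         // 40 / 35 = 1
  int residual = offset_nhw % (H * W);  // 5
  int h = residual / W;                 // 0
  int w = residual % W;                 // 5
  assert(n == 1 && h == 0 && w == 5);
  return 0;
}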
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int n = offset_n_[iteration_strided_]; + int h = offset_h_[iteration_strided_]; + int w = offset_w_[iteration_strided_]; + + int r = filter_r_; + int s = filter_s_; + + if (problem_size_.mode == Mode::kConvolution) { + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h) / problem_size_.stride_h; + int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w) / problem_size_.stride_w; + + int k = filter_k_ + iteration_vector_ * AccessType::kElements; + + return TensorCoord(n, p, q, k); + } + + /// Returns true if the current coordinate is within the output tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.h() >= 0 && coord.h() < problem_size_.P && + coord.w() >= 0 && coord.w() < problem_size_.Q && + coord.c() < problem_size_.K; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // Conv2dDgradFilterTileAccessIteratorAnalytic unity stride specialization + // only supports (stride_h, stride_w) = (1, 1) + if (problem_size.stride() != MatrixCoord({1, 1})) { + return Status::kErrorNotSupported; + } + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h new file mode 100644 index 00000000..b0307a5f --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -0,0 +1,821 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dDgradOutputGradientTileAccessIteratorOptimized; +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Conv2dDgradOutputGradientTileAccessIteratorOptimized strided dgrad needs special handling +// to skip MMAs (Dx = Dy * w) on invalid filter positions +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradOutputGradientTileAccessIteratorOptimized < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kStrided, + AccessType_ +> { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = uint64_t; + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or greater."); + + // + // Simpligying assertions + // + + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dStridedDgradOutputGradientIteratorOptimizedParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + + // One pointer per access + char const *pointer_[ThreadMap::Iterations::kStrided]; + + int filter_k_; + int filter_r_; + int filter_s_; + int start_r_; + int start_s_; + int64_t reset_bytes_s_; + int64_t reset_bytes_r_; + + Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorOptimized( + 
Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles + ): + params_(params), + problem_size_(problem_size), + filter_k_(0), + filter_r_(start_r), + filter_s_(start_s), + start_r_(start_r), + start_s_(start_s) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0]; + + reset_bytes_r_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0] + + (problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1]; + + int offset_n[ThreadMap::Iterations::kStrided]; + int offset_p[ThreadMap::Iterations::kStrided]; + int offset_q[ThreadMap::Iterations::kStrided]; + + int filter_r = filter_r_; + int filter_s = filter_s_; + + if (problem_size_.mode == Mode::kConvolution) { + filter_r = (problem_size_.R - 1 - filter_r); + filter_s = (problem_size_.S - 1 - filter_s); + } + + // Starting h, w positions for filter position in gemm_k=0 + int start_h, start_w; + strided_dgrad_starting_coords( + problem_size_, + stride_h_divmod, stride_w_divmod, + filter_r, filter_s, + start_h, start_w); + + + // Effective starting P and Q for filter position required for remapping NHW rows + int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h; + int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + pointer_[s] = reinterpret_cast(ptr); + + int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter; + + // (STEP 1) [reorder NHW rows to start with same filter positions] + offset_n[s] = offset_npq / (P * Q); + int residual = offset_npq % (P * Q); + + int p = (residual / Q); + int q = (residual % Q); + + int mapped_h = (start_h + p * problem_size_.stride_h); + int mapped_w = (start_w + q * problem_size_.stride_w); + + // Access (p, q) coordinates for Dy tensor for filter position in gemm_k=0 + // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are ensured to be + // divisible by stride_h and stride_w + offset_p[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h; + offset_q[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w; + + // Initialize pointers for gemm_k=0 + TensorCoord coord{offset_n[s], offset_p[s], offset_q[s], filter_k_}; + + pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; + } + + // + // Precompute mask predicates + // + clear_mask(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int r = start_r; r < problem_size_.R; r += problem_size_.stride_h) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int p = offset_p[s_idx] ; + + p += (params_.conv_sign * (r / problem_size_.stride_h)); + + bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][0] |= (pred << r); + } + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for(int s = start_s; s < problem_size_.S; s += problem_size_.stride_w) 
{ + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int q = offset_q[s_idx]; + q += (params_.conv_sign * (s / problem_size_.stride_w)); + + bool pred = (q >=0 && q < problem_size_.Q); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][1] |= (pred << s); + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size.K); + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}); + } + +private: + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_byte_offset_(LongIndex byte_offset, LongIndex byte_reset = 0) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + pointer_[s] += byte_offset - byte_reset; + } + } + +public: + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + add_byte_offset_(pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void advance() { + + int next_idx = 0; + int64_t reset_bytes = 0; + + // Move filter_s by stride_w + filter_s_ += problem_size_.stride_w; + if (filter_s_ >= problem_size_.S) { + + // Restore filter_s + filter_s_ = start_s_; + + // Move filter_r by stride_h + filter_r_ += problem_size_.stride_h; +#if 0 + if (filter_r_ < problem_size_.R) { + + next_idx = 1; + + // Restore bytes in q coordinate (Mma in filter s dimension) + reset_bytes = reset_bytes_s_; + + } else { + + // Restore filter_r + filter_r_ = start_r_; + + next_idx = 2; + + // Restore bytes in p and q coordinate (Mma in filter s and r dimension) + reset_bytes = reset_bytes_r_; + } +#else + asm volatile( + "{\n\t" + " .reg .pred %%p;\n\t" + " setp.lt.s32 %%p, %3, %4;\n\t" + " selp.s32 %0, %3, %5, %%p;\n\t" + " selp.s32 %1, 1, 2, %%p;\n\t" + " selp.s64 %2, %6, %7, %%p;\n\t" + "}\n" + : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) + : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), + "l"(reset_bytes_s_), "l"(reset_bytes_r_)); +#endif + } + + // offset pointers by offset_bytes + add_byte_offset_(params_.inc_next[next_idx] - reset_bytes); + + if (next_idx == 2) { + filter_k_ += params_.filter_k_delta; + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K); + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(bool clear = true) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; + masks_[s][v][1] = clear ? 
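// [Illustrative sketch - not part of the diff.] Instead of recomputing bounds checks per
// load, the optimized strided-dgrad iterator above builds, per access, one bitmask with a
// bit per filter row (masks_[..][..][0]) and one with a bit per filter column
// (masks_[..][..][1]); valid() then just shifts by the current filter_r_/filter_s_ and ANDs.
// A standalone sketch of that idea with assumed bounds and an assumed traversal direction:
#include <cassert>
#include <cstdint>

int main() {
  int R = 3, S = 3;     // assumed filter extent
  int P = 4, Q = 4;     // assumed output-gradient extent
  int p0 = 3, q0 = -1;  // assumed base (p, q) for this access at r = s = 0

  std::uint64_t mask_r = 0, mask_s = 0;
  for (int r = 0; r < R; ++r) {
    int p = p0 - r;  // cross-correlation: larger r reads earlier rows (assumed here)
    mask_r |= (std::uint64_t(p >= 0 && p < P) << r);
  }
  for (int s = 0; s < S; ++s) {
    int q = q0 + s;  // assumed direction for this example
    mask_s |= (std::uint64_t(q >= 0 && q < Q) << s);
  }

  // valid() for the current filter position is just two bit tests.
  auto valid = [&](int filter_r, int filter_s) {
    return (mask_r & (std::uint64_t(1) << filter_r)) &&
           (mask_s & (std::uint64_t(1) << filter_s));
  };
  assert(valid(0, 1));   // p = 3, q = 0  -> in range
  assert(!valid(0, 0));  // q = -1        -> out of range
  return 0;
}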
Mask(0) : masks_[s][v][1]; + } + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(int v, bool clear = true) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; + masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; + } + } + + /// Returns true if the current coordinate is within the output tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + return + (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && + (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + // Limit on filter size + if (problem_size.R > 32 || problem_size.S > 32) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Conv2dDgradOutputGradientTileAccessIteratorOptimized unity stride dgrad is optimized for dgrad +// with problem stride = {1x1} +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradOutputGradientTileAccessIteratorOptimized < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kUnity, + AccessType_ +> { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = uint64_t; + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure 
+ // + + using Params = Conv2dDgradOutputGradientIteratorOptimizedParams; + +private: + + Conv2dDgradOutputGradientIteratorOptimizedParams const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + + // One pointer per access + char const *pointer_[ThreadMap::Iterations::kStrided]; + + // current filter position (r, s) + int filter_r_; + int filter_s_; + int filter_k_; + + Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorOptimized( + Conv2dDgradOutputGradientIteratorOptimizedParams const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ): + params_(params), + problem_size_(problem_size), + filter_k_(0), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + int offset_n[ThreadMap::Iterations::kStrided]; + int offset_h[ThreadMap::Iterations::kStrided]; + int offset_w[ThreadMap::Iterations::kStrided]; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + pointer_[s] = reinterpret_cast(ptr); + + int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // offset_n[s] = offset_nhw / (problem_size_.H * problem_size_.W); + // int residual = offset_nhw % (problem_size_.H * problem_size_.W); + // + // offset_h[s] = residual / problem_size_.W; + // offset_w[s] = residual % problem_size_.W; + // + + int residual; + + params_.hw_divmod(offset_n[s], residual, offset_nhw); + params_.w_divmod(offset_h[s], offset_w[s], residual); + + TensorCoord coord = at_(offset_n[s], offset_h[s], offset_w[s], 0, 0); + + pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; + } + + clear_mask(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int r = 0; r < problem_size_.R; ++r) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int r_ = r; + if (problem_size_.mode == Mode::kConvolution) { + r_ = problem_size_.R - 1 - r; + } + + int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h; + + bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][0] |= (pred << r); + } + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int s = 0; s < problem_size_.S; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int s_ = s; + if (problem_size_.mode == Mode::kConvolution) { + s_ = problem_size_.S - 1 - s; + } + + int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w; + + bool pred = (q >= 0 && q < problem_size_.Q); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][1] |= (pred << s); + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, filter_k_ + v_idx * AccessType::kElements >= problem_size.K); + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params 
getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); + } + +private: + + /// Returns the coordinate in the output gradient tensor dy that is correspoinding to + // activation nhw and filter position k, r, s + CUTLASS_HOST_DEVICE + TensorCoord at_(int n, int h, int w, int r, int s) const { + + if (problem_size_.mode == Mode::kConvolution) { + r = problem_size_.R - 1 - r; + s = problem_size_.S - 1 - s; + } + + int p = h + problem_size_.pad_h - r * problem_size_.dilation_h; + int q = w + problem_size_.pad_w - s * problem_size_.dilation_w; + + return TensorCoord(n, p, q, filter_k_); + } + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_byte_offset_(LongIndex byte_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + pointer_[s] += byte_offset; + } + } + +public: + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + add_byte_offset_(pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_HOST_DEVICE + void advance() { + + int next_idx = 0; + + // moves to the next tile + ++filter_s_; + if (filter_s_ == problem_size_.S) { + filter_s_ = 0; + ++filter_r_; + + if (filter_r_ < problem_size_.R) { + next_idx = 1; + } + else { + filter_r_ = 0; + next_idx = 2; + } + } + + add_byte_offset_(params_.inc_next[next_idx]); + + if (next_idx == 2) { + filter_k_ += params_.filter_k_delta; + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K); + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(bool clear = true) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; + masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; + } + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(int v, bool clear = true) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; + masks_[s][v][1] = clear ? 
Mask(0) : masks_[s][v][1]; + } + } + + CUTLASS_HOST_DEVICE + bool valid() { + + return + (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && + (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // This is specialized for unit stride + if (problem_size.stride() != MatrixCoord({1, 1})) { + return Status::kErrorNotSupported; + } + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorNotSupported; + } + + // Limit on filter size + if (problem_size.R > 32 || problem_size.S > 32) { + return Status::kErrorNotSupported; + } + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h new file mode 100644 index 00000000..1f7396e4 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h @@ -0,0 +1,332 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray, + conv::GroupMode GroupMode_ = conv::GroupMode::kNone +> +class Conv2dFpropActivationTileAccessIteratorAnalytic { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + static conv::GroupMode const kGroupMode = GroupMode_; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dAnalyticParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + 
LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_c_; + int filter_r_; + int filter_s_; + int filter_c_init_; + int group_idx_offset_; + int channels_per_group_; + int crs_cnt_; + int crs_per_group_; + + int offset_n_[ThreadMap::Iterations::kStrided]; + int offset_p_[ThreadMap::Iterations::kStrided]; + int offset_q_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + crs_cnt_(0), + group_idx_offset_(0), + filter_c_(0), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); + + if (kGroupMode != conv::GroupMode::kNone) { + filter_c_init_ = filter_c_; + channels_per_group_ = problem_size_.C / problem_size_.groups; + crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kColumn - 1) / Shape::kColumn); + } + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); + int residual = offset_npq % (problem_size_.P * problem_size_.Q); + + offset_p_[s] = residual / problem_size_.Q; + offset_q_[s] = residual % problem_size_.Q; + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + if (kGroupMode != conv::GroupMode::kNone) { + ++crs_cnt_; + } + + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + + if (kGroupMode == conv::GroupMode::kNone) { + filter_c_ += Shape::kColumn * problem_size_.split_k_slices; + } else { + if (crs_cnt_ == crs_per_group_) { + // moves to next group + crs_cnt_ = 0; + ++group_idx_offset_; + filter_c_ = group_idx_offset_ * channels_per_group_ + filter_c_init_; + } else { + filter_c_ += Shape::kColumn * problem_size_.split_k_slices; + } + } + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. 
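Concretely, at() (which follows) combines two pieces of arithmetic: the constructor decomposes the GEMM row index offset_npq into (n, p, q) via offset_npq / (P * Q), residual / Q and residual % Q, and the filter offset is then applied as h = p * stride_h - pad_h + r * dilation_h (with r and s mirrored when the mode is kConvolution rather than cross-correlation). Below is a minimal stand-alone sketch of that arithmetic with hypothetical names; it is an illustration, not code from this header.

#include <cstdio>

// Illustrative sketch: map a GEMM row index and a filter position to an
// activation coordinate, the way the analytic fprop iterator does.
struct ToyFpropProblem {
  int P, Q;                      // output height/width
  int stride_h, stride_w;
  int pad_h, pad_w;
  int dilation_h, dilation_w;
  int R, S;                      // filter extent
  bool convolution_mode;         // true: kConvolution (filter mirrored)
};

struct Coord4 { int n, h, w, c; };

Coord4 activation_coord(ToyFpropProblem const &ps, int offset_npq, int r, int s, int c) {
  // Decompose the GEMM row index into (n, p, q).
  int n = offset_npq / (ps.P * ps.Q);
  int residual = offset_npq % (ps.P * ps.Q);
  int p = residual / ps.Q;
  int q = residual % ps.Q;

  // Mirror the filter for true convolution.
  if (ps.convolution_mode) {
    r = ps.R - 1 - r;
    s = ps.S - 1 - s;
  }

  // Output position + filter offset -> input position.
  int h = p * ps.stride_h - ps.pad_h + r * ps.dilation_h;
  int w = q * ps.stride_w - ps.pad_w + s * ps.dilation_w;
  return {n, h, w, c};
}

int main() {
  ToyFpropProblem ps{4, 4, 1, 1, 1, 1, 1, 1, 3, 3, false};
  Coord4 x = activation_coord(ps, /*offset_npq=*/21, /*r=*/2, /*s=*/0, /*c=*/5);
  std::printf("n=%d h=%d w=%d c=%d\n", x.n, x.h, x.w, x.c);  // may land out of bounds; valid() checks that
  return 0;
}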
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + int n = offset_n_[iteration_strided_]; + int p = offset_p_[iteration_strided_]; + int q = offset_q_[iteration_strided_]; + + int r = filter_r_; + int s = filter_s_; + + if (problem_size_.mode == Mode::kConvolution) { + r = (problem_size_.R - 1 - filter_r_); + s = (problem_size_.S - 1 - filter_s_); + } + + int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; + int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; + + int c = filter_c_ + iteration_vector_ * AccessType::kElements; + + return TensorCoord(n, h, w, c); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W && + coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropActivationTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if ((problem_size.C / problem_size.groups) % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (problem_size.C % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (problem_size.C % 64) { + return Status::kErrorInvalidProblem; + } + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h new file mode 100644 index 00000000..5a4489c0 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h @@ -0,0 +1,360 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dFpropActivationTileAccessIteratorFewChannels { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFewChannels; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kPositionsPerTile = Shape::kColumn; + + static int const kAccessesPerVector = kElementsPerAccess / AccessType::kElements; + + static bool const kUseFastDivmodPrologue = true; + static bool const kUseFastDivmodMainloop = true; + + static int const kStrideH = 0; + static int const kStrideW = 0; + static int const kDilationH = 0; + static int const kDilationW = 0; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dFewChannelsParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int rsc_index_; + int offset_n_[ThreadMap::Iterations::kStrided]; + int offset_p_[ThreadMap::Iterations::kStrided]; + int offset_q_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationTileAccessIteratorFewChannels( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + rsc_index_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + rsc_index_ = (threadblock_offset.column() + thread_coord.contiguous()); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; 
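      // [Illustrative aside, not part of the original CUTLASS source]
      // The divmod_Q / divmod_P objects used in the fast path below implement integer
      // division by a runtime-constant divisor with a precomputed multiplier and shift,
      // so the mainloop avoids hardware divide instructions. A simplified sketch of the
      // technique (assuming non-negative 31-bit numerators; not the exact CUTLASS code):
      //
      //   // precompute once for divisor d:
      //   //   l = ceil(log2(d)), m = ceil(2^(31+l) / d)
      //   // then, for any 0 <= n < 2^31:
      //   //   quotient  = (uint64_t(n) * m) >> (31 + l);
      //   //   remainder = n - quotient * d;
      //
      // The else-branch performs the same (n, p, q) decomposition with ordinary '/' and '%'.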
+ + if (kUseFastDivmodPrologue) { + int residual = params_.divmod_Q.divmod(offset_q_[s], offset_npq); + offset_n_[s] = params_.divmod_P.divmod(offset_p_[s], residual); + } + else { + offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); + int residual = offset_npq % (problem_size_.P * problem_size_.Q); + + offset_p_[s] = residual / problem_size_.Q; + offset_q_[s] = residual % problem_size_.Q; + } + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + rsc_index_ += kPositionsPerTile * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + int n = offset_n_[iteration_strided_]; + int p = offset_p_[iteration_strided_]; + int q = offset_q_[iteration_strided_]; + + int rsc_index = rsc_index_ + iteration_vector_ * AccessType::kElements; + + int r = 0; + int s = 0; + int c = 0; + + if (kUseFastDivmodMainloop) { + int rs_index = params_.divmod_C.divmod(c, rsc_index); + r = params_.divmod_S.divmod(s, rs_index); + } + else { + c = (rsc_index % problem_size_.C); + + int rs_index = (rsc_index / problem_size_.C); + s = (rs_index % problem_size_.S); + r = (rs_index / problem_size_.S); + } + + if (problem_size_.mode == Mode::kConvolution) { + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + int stride_h = kStrideH; + if (!kStrideH) { + stride_h = problem_size_.stride_h; + } + + int stride_w = kStrideW; + if (!kStrideW) { + stride_w = problem_size_.stride_w; + } + + int dilation_h = kDilationH; + if (!kDilationH) { + dilation_h = problem_size_.dilation_h; + } + + int dilation_w = kDilationW; + if (!kDilationW) { + dilation_w = problem_size_.dilation_w; + } + + int h = p * stride_h - problem_size_.pad_h + r * dilation_h; + int w = q * stride_w - problem_size_.pad_w + s * dilation_w; + + return TensorCoord(n, h, w, c); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + bool in_bounds = + coord.n() < problem_size_.N && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W && + coord.c() < problem_size_.C; + + return in_bounds; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + + int32_t offset = + coord.n() * params_.stride_n + + coord.h() * params_.stride_h + + coord.w() * params_.stride_w + + coord.c(); + + AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropActivationTileAccessIteratorFewChannels 
&operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (kDilationH && problem_size.dilation_h != kDilationH) { + return Status::kErrorInvalidProblem; + } + + if (kDilationW && problem_size.dilation_w != kDilationW) { + return Status::kErrorInvalidProblem; + } + + if (kStrideH && problem_size.stride_h != kStrideH) { + return Status::kErrorInvalidProblem; + } + + if (kStrideW && problem_size.stride_w != kStrideW) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (problem_size.C % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (problem_size.C % 64) { + return Status::kErrorInvalidProblem; + } + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h new file mode 100644 index 00000000..3f1f2bc1 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h @@ -0,0 +1,353 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dFpropActivationTileAccessIteratorFixedChannels { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFixedChannels; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kFilterPositionsPerTile = Shape::kColumn / AccessType::kElements; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static bool const kUseFastDivmodPrologue = true; + static bool const kUseFastDivmodMainloop = true; + + static int const kStrideH = 0; + static int const kStrideW = 0; + static int const kDilationH = 0; + static int const kDilationW = 0; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dFewChannelsParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex 
iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int rs_index_; + int offset_n_[ThreadMap::Iterations::kStrided]; + int offset_p_[ThreadMap::Iterations::kStrided]; + int offset_q_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationTileAccessIteratorFixedChannels( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + rs_index_(0) { + + // + // This requires problem_size.C == AccessType::kElements + // + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + rs_index_ = (threadblock_offset.column() + thread_coord.contiguous()) / AccessType::kElements; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + if (kUseFastDivmodPrologue) { + int residual = params_.divmod_Q.divmod(offset_q_[s], offset_npq); + offset_n_[s] = params_.divmod_P.divmod(offset_p_[s], residual); + } + else { + offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); + int residual = offset_npq % (problem_size_.P * problem_size_.Q); + + offset_p_[s] = residual / problem_size_.Q; + offset_q_[s] = residual % problem_size_.Q; + } + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + rs_index_ += kFilterPositionsPerTile * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. 
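The fixed-channels specialization assumes the whole channel dimension is covered by a single vector access (problem_size.C == AccessType::kElements, enforced by can_implement() below), so only the filter position (r, s) has to be unpacked from rs_index_, whereas the few-channels iterator above also unpacks a channel offset from rsc_index_. The stand-alone sketch below contrasts the two decompositions using hypothetical helper names; it is not code from these headers.

#include <cstdio>

// Illustrative sketch: how a linear filter-position index is unpacked by the
// few-channels vs. fixed-channels fprop activation iterators.
struct Rsc { int r, s, c; };

// Few channels: the index runs over (r, s, c) with c fastest.
Rsc decompose_rsc(int rsc_index, int S, int C) {
  int c = rsc_index % C;
  int rs_index = rsc_index / C;
  int s = rs_index % S;
  int r = rs_index / S;
  return {r, s, c};
}

// Fixed channels: C is covered by one vector access, so the index runs over (r, s) only.
Rsc decompose_rs(int rs_index, int S) {
  int s = rs_index % S;
  int r = rs_index / S;
  return {r, s, /*c=*/0};
}

int main() {
  Rsc a = decompose_rsc(/*rsc_index=*/23, /*S=*/3, /*C=*/4);  // r=1, s=2, c=3
  Rsc b = decompose_rs(/*rs_index=*/5, /*S=*/3);              // r=1, s=2
  std::printf("few:   r=%d s=%d c=%d\n", a.r, a.s, a.c);
  std::printf("fixed: r=%d s=%d c=%d\n", b.r, b.s, b.c);
  return 0;
}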
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + int n = offset_n_[iteration_strided_]; + int p = offset_p_[iteration_strided_]; + int q = offset_q_[iteration_strided_]; + + int rs_index = rs_index_ + iteration_vector_; + + int r = 0; + int s = 0; + + if (kUseFastDivmodMainloop) { + r = params_.divmod_S.divmod(s, rs_index); + } + else { + s = (rs_index % problem_size_.S); + r = (rs_index / problem_size_.S); + } + + if (problem_size_.mode == Mode::kConvolution) { + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + int stride_h = kStrideH; + if (!kStrideH) { + stride_h = problem_size_.stride_h; + } + + int stride_w = kStrideW; + if (!kStrideW) { + stride_w = problem_size_.stride_w; + } + + int dilation_h = kDilationH; + if (!kDilationH) { + dilation_h = problem_size_.dilation_h; + } + + int dilation_w = kDilationW; + if (!kDilationW) { + dilation_w = problem_size_.dilation_w; + } + + int h = p * stride_h - problem_size_.pad_h + r * dilation_h; + int w = q * stride_w - problem_size_.pad_w + s * dilation_w; + + return TensorCoord(n, h, w, 0); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + + int32_t offset = + coord.n() * params_.stride_n + + coord.h() * params_.stride_h + + coord.w() * params_.stride_w + coord.c(); + + AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropActivationTileAccessIteratorFixedChannels &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
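can_implement(), which follows, is the host-side guard for every assumption this specialization bakes in at compile time: the channel count must equal the vector width, a stride or dilation fixed to a non-zero compile-time value must match the runtime problem exactly, and interleaved layouts add channel-multiple requirements. The sketch below mirrors that dispatch pattern with hypothetical names; it is an illustration of the idea, not the CUTLASS API.

#include <cstdio>

// Illustrative sketch (hypothetical names): the kind of checks can_implement()
// performs before a specialized iterator is selected for a problem.
struct ToyConvProblem {
  int C;                         // input channels
  int stride_h, stride_w;
  int dilation_h, dilation_w;
};

struct ToyIteratorTraits {
  int vector_elements;           // AccessType::kElements
  int fixed_stride_h;            // 0 means "any stride", otherwise must match exactly
  int fixed_stride_w;
  int fixed_dilation_h;          // 0 means "any dilation"
  int fixed_dilation_w;
};

bool toy_can_implement(ToyConvProblem const &p, ToyIteratorTraits const &t) {
  if (p.C != t.vector_elements) return false;                        // fixed-channels assumption
  if (t.fixed_stride_h && p.stride_h != t.fixed_stride_h) return false;
  if (t.fixed_stride_w && p.stride_w != t.fixed_stride_w) return false;
  if (t.fixed_dilation_h && p.dilation_h != t.fixed_dilation_h) return false;
  if (t.fixed_dilation_w && p.dilation_w != t.fixed_dilation_w) return false;
  return true;                                                       // otherwise fall back to another iterator
}

int main() {
  ToyConvProblem p{/*C=*/4, /*stride_h=*/2, /*stride_w=*/2, /*dilation_h=*/1, /*dilation_w=*/1};
  ToyIteratorTraits t{/*vector_elements=*/4, 0, 0, 0, 0};
  std::printf("fixed-channels path usable: %d\n", toy_can_implement(p, t));
  return 0;
}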
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C != AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (kDilationH && problem_size.dilation_h != kDilationH) { + return Status::kErrorInvalidProblem; + } + + if (kDilationW && problem_size.dilation_w != kDilationW) { + return Status::kErrorInvalidProblem; + } + + if (kStrideH && problem_size.stride_h != kStrideH) { + return Status::kErrorInvalidProblem; + } + + if (kStrideW && problem_size.stride_w != kStrideW) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (problem_size.C % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (problem_size.C % 64) { + return Status::kErrorInvalidProblem; + } + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h new file mode 100644 index 00000000..43056a69 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h @@ -0,0 +1,422 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dFpropActivationTileAccessIteratorOptimized { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + using Mask = uint64_t; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dFpropActivationIteratorOptimizedParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + + // One pointer per access + char const *pointer_[ThreadMap::Iterations::kStrided]; + + // current filter position (r, s) + int filter_r_; + int filter_s_; + int filter_c_; + + Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationTileAccessIteratorOptimized( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ): + params_(params), + problem_size_(problem_size), + filter_c_(0), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); + + int offset_n[ThreadMap::Iterations::kStrided]; + int 
offset_p[ThreadMap::Iterations::kStrided]; + int offset_q[ThreadMap::Iterations::kStrided]; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + pointer_[s] = reinterpret_cast(ptr); + + int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // offset_n[s] = offset_npq / (problem_size_.P * problem_size_.Q); + // int residual = offset_npq % (problem_size_.P * problem_size_.Q); + // + // offset_p[s] = residual / problem_size_.Q; + // offset_q[s] = residual % problem_size_.Q; + // + + int residual; + + params.pq_divmod(offset_n[s], residual, offset_npq); + params.q_divmod(offset_p[s], offset_q[s], residual); + + TensorCoord coord = at_(offset_n[s], offset_p[s], offset_q[s], 0, 0); + + pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; + } + + clear_mask(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int r = 0; r < problem_size_.R; ++r) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int r_ = r; + if (problem_size_.mode == Mode::kConvolution) { + r_ = problem_size_.R - 1 - r; + } + + int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h; + + bool pred = (offset_n[s_idx] < problem_size_.N && h >= 0 && h < problem_size_.H); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][0] |= (pred << r); + } + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int s = 0; s < problem_size_.S; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int s_ = s; + if (problem_size_.mode == Mode::kConvolution) { + s_ = problem_size_.S - 1 - s; + } + + int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w; + + bool pred = (w >= 0 && w < problem_size_.W); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][1] |= (pred << s); + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); + } + +private: + + /// Returns the coordinate in the activations tensor X that is correspoinding to + // output npq and filter position r, s + CUTLASS_HOST_DEVICE + TensorCoord at_(int n, int p, int q, int r, int s) const { + + if (problem_size_.mode == Mode::kConvolution) { + r = problem_size_.R - 1 - r; + s = problem_size_.S - 1 - s; + } + + int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; + int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; + + return TensorCoord(n, h, w, filter_c_); + } + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_byte_offset_(LongIndex byte_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + 
pointer_[s] += byte_offset; + } + } + +public: + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + add_byte_offset_(pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_HOST_DEVICE + void advance() { + + int next_idx = 0; + + // moves to the next tile + ++filter_s_; + if (filter_s_ == problem_size_.S) { + filter_s_ = 0; + ++filter_r_; + + if (filter_r_ < problem_size_.R) { + next_idx = 1; + } + else { + filter_r_ = 0; + next_idx = 2; + } + } + + add_byte_offset_(params_.inc_next[next_idx]); + + if (next_idx == 2) { + filter_c_ += params_.filter_c_delta; + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(bool clear = true) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + masks_[s][v][0] = clear ? 0 : masks_[s][v][0]; + masks_[s][v][1] = clear ? 0 : masks_[s][v][1]; + } + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(int v, bool clear = true) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + masks_[s][v][0] = clear ? 0 : masks_[s][v][0]; + masks_[s][v][1] = clear ? 0 : masks_[s][v][1]; + } + } + + CUTLASS_HOST_DEVICE + bool valid() { + + return + (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && + (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropActivationTileAccessIteratorOptimized &operator++() { + + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if ((problem_size.C / problem_size.groups) % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (problem_size.C % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (problem_size.C % 64) { + return Status::kErrorInvalidProblem; + } + } + + // Conv2dFpropActivationTileAccessIteratorOptimized has constraint on filter positions + // due to the number of mask bits. 
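    // [Illustrative note, not part of the original source] masks_ holds one predicate
    // bit per filter row in word [0] and one per filter column in word [1], and valid()
    // tests
    //   masks_[...][0] & (Index(1) << filter_r_)  and  masks_[...][1] & (Index(1) << filter_s_).
    // With 32 usable bits per mask word, a filter extent beyond 32 in either R or S cannot
    // be represented, which is exactly what the check below rejects.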
+ if (problem_size.R > 32 || problem_size.S > 32) { + return Status::kErrorNotSupported; + } + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h new file mode 100644 index 00000000..434c288a --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h @@ -0,0 +1,319 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray, + conv::GroupMode GroupMode_ = conv::GroupMode::kNone +> +class Conv2dFpropFilterTileAccessIteratorAnalytic { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + static conv::GroupMode const kGroupMode = GroupMode_; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dAnalyticParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_r_; + int filter_s_; + int filter_c_; + int filter_c_init_; + int crs_cnt_; + int crs_per_group_; + int group_idx_offset_c_; + int channels_per_group_; + + int offset_k_[ThreadMap::Iterations::kStrided]; + int group_idx_offset_k_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + crs_cnt_(0), + group_idx_offset_c_(0), + filter_r_(0), + filter_s_(0), + filter_c_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); + + if (kGroupMode != conv::GroupMode::kNone) { + filter_c_init_ = filter_c_; + if (kGroupMode == conv::GroupMode::kDepthwise){ + channels_per_group_ = 1; + crs_per_group_ = problem_size_.S * problem_size_.R; + } else { + channels_per_group_ = problem_size_.C / problem_size_.groups; + crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + 
Shape::kRow - 1) / Shape::kRow); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + if (kGroupMode != conv::GroupMode::kNone && kGroupMode != conv::GroupMode::kDepthwise) { + group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (problem_size_.K / problem_size_.groups); + } + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * 8 / sizeof_bits::value; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + if (kGroupMode != conv::GroupMode::kNone) { + ++crs_cnt_; + } + + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + + if (kGroupMode == conv::GroupMode::kNone) { + filter_c_ += Shape::kRow * problem_size_.split_k_slices; + } else { + if (crs_cnt_ == crs_per_group_) { + crs_cnt_ = 0; + filter_c_ = filter_c_init_; + if (kGroupMode != conv::GroupMode::kDepthwise) { + // moves to next group + ++group_idx_offset_c_; + } + } else { + filter_c_ += Shape::kRow * problem_size_.split_k_slices; + } + } + } + + /// Returns the coordinate in the filter tensor W that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int k = offset_k_[iteration_strided_]; + int c = filter_c_ + iteration_vector_ * AccessType::kElements; + + return TensorCoord(k, filter_r_, filter_s_, c); + } + + /// Returns true if the current coordinate is within the activations tensor W + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + if (kGroupMode == conv::GroupMode::kNone) { + return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + } else if (kGroupMode == conv::GroupMode::kDepthwise) { + return coord.n() < problem_size_.K && coord.c() < 1; // channels_per_group_ is always equal to ONE. 
+ } else { + return coord.n() < problem_size_.K && coord.c() < channels_per_group_ && + group_idx_offset_c_ == group_idx_offset_k_[iteration_strided_]; + } + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if ((problem_size.C / problem_size.groups) % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (problem_size.K % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (problem_size.K % 64) { + return Status::kErrorInvalidProblem; + } + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h new file mode 100644 index 00000000..a1291aa0 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h @@ -0,0 +1,289 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dFpropFilterTileAccessIteratorFewChannels { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFewChannels; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kPositionsPerTile = Shape::kRow; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static bool const kUseFastDivmodPrologue = true; + static bool const kUseFastDivmodMainloop = true; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dFewChannelsParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int rsc_index_; + + int 
offset_k_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorFewChannels( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + rsc_index_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + rsc_index_ = (threadblock_offset.row() + thread_coord.contiguous()); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * 8 / sizeof_bits::value; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + rsc_index_ += kPositionsPerTile * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the filter tensor W that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int rsc_index = rsc_index_ + iteration_vector_ * AccessType::kElements; + + int c = 0; + int s = 0; + int r = 0; + + if (kUseFastDivmodMainloop) { + int rs_index = params_.divmod_C.divmod(c, rsc_index); + r = params_.divmod_S.divmod(s, rs_index); + } + else { + c = (rsc_index % problem_size_.C); + int rs_index = (rsc_index / problem_size_.C); + + s = (rs_index % problem_size_.S); + r = (rs_index / problem_size_.S); + } + + int k = offset_k_[iteration_strided_]; + + return TensorCoord(k, r, s, c); + } + + /// Returns true if the current coordinate is within the activations tensor W + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + bool in_bounds = + coord.n() < problem_size_.K && + coord.h() >= 0 && + coord.h() < problem_size_.R && + coord.c() < problem_size_.C; + + return in_bounds; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + + int32_t offset = + coord.n() * params_.stride_n + + coord.h() * params_.stride_h + + coord.w() * params_.stride_w + + coord.c(); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorFewChannels &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
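`at()` above recovers `(r, s, c)` from the single running index `rsc_index_`; with `kUseFastDivmodMainloop` enabled the divisions go through the precomputed `divmod_C` / `divmod_S` objects, otherwise through plain `/` and `%`. A small standalone sketch of the plain-integer path (hypothetical sizes; the fast path only replaces the divisions with host-precomputed multiply-shift sequences):

```cpp
// Reference decomposition of the few-channels iterator's linear index:
// rsc_index enumerates (r, s, c) with c fastest, matching the else-branch of at().
#include <cstdio>
#include <tuple>

std::tuple<int, int, int> decompose_rsc(int rsc_index, int S, int C) {
  int c = rsc_index % C;   // channel is the fastest-varying coordinate
  int rs = rsc_index / C;
  int s = rs % S;          // then the filter column
  int r = rs / S;          // and finally the filter row
  return {r, s, c};
}

int main() {
  int const R = 3, S = 3, C = 4;  // "few channels": C is small, so this stays cheap
  for (int idx = 0; idx < R * S * C; idx += 7) {
    auto [r, s, c] = decompose_rsc(idx, S, C);
    std::printf("idx %2d -> (r=%d, s=%d, c=%d)\n", idx, r, s, c);
  }
  return 0;
}
```

`FastDivmod` precomputes a magic multiplier and shift for each divisor on the host, so the device code performs these divisions with integer multiplies rather than hardware division.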
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (problem_size.K % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (problem_size.K % 64) { + return Status::kErrorInvalidProblem; + } + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h new file mode 100644 index 00000000..e90d5017 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h @@ -0,0 +1,275 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
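The fixed-channels specialization that follows assumes the channel count `C` exactly equals the vector width `AccessType::kElements`, so a single vector access covers every channel of one filter position and the iterator only has to walk a linear `(r, s)` index. A brief standalone sketch of that simplification (hypothetical sizes; the equality constraint itself is enforced by `can_implement()`):

```cpp
// Sketch of the "fixed channels" simplification: when C equals the vector width,
// one access loads all channels of a filter position, leaving only an (r, s) walk.
#include <cassert>
#include <cstdio>

int main() {
  int const C = 4;          // problem channel count
  int const kElements = 4;  // AccessType::kElements (vector width)
  assert(C == kElements);   // the constraint can_implement() enforces

  int const R = 3, S = 3;
  for (int rs_index = 0; rs_index < R * S; ++rs_index) {
    int s = rs_index % S;   // plain-integer form of divmod_S.divmod(s, rs_index)
    int r = rs_index / S;
    std::printf("rs %d -> (r=%d, s=%d), channels [0..%d) in one access\n",
                rs_index, r, s, C);
  }
  return 0;
}
```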
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dFpropFilterTileAccessIteratorFixedChannels { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFixedChannels; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kFilterPositionsPerTile = Shape::kRow / AccessType::kElements; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static bool const kUseFastDivmodPrologue = true; + static bool const kUseFastDivmodMainloop = true; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dFewChannelsParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int rs_index_; + + int offset_k_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorFixedChannels( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + rs_index_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + rs_index_ = (threadblock_offset.row() + thread_coord.contiguous()) / AccessType::kElements; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + 
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * 8 / sizeof_bits::value; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + rs_index_ += kFilterPositionsPerTile * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the filter tensor W that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int rs_index = rs_index_ + iteration_vector_; + + int r = 0; + int s = 0; + + if (kUseFastDivmodMainloop) { + r = params_.divmod_S.divmod(s, rs_index); + } + else { + s = (rs_index % problem_size_.S); + r = (rs_index / problem_size_.S); + } + + int k = offset_k_[iteration_strided_]; + + return TensorCoord(k, r, s, 0); + } + + /// Returns true if the current coordinate is within the activations tensor W + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && coord.h() >= 0 && coord.h() < problem_size_.R; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + + int32_t offset = + coord.n() * params_.stride_n + + coord.h() * params_.stride_h + + coord.w() * params_.stride_w + coord.c(); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorFixedChannels &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C != AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (problem_size.K % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (problem_size.K % 64) { + return Status::kErrorInvalidProblem; + } + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h new file mode 100644 index 00000000..37aeb4f6 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -0,0 +1,317 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dFpropFilterTileAccessIteratorOptimized{ +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + struct Params : Conv2dFpropFilterIteratorOptimizedParams { + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Conv2dFpropFilterIteratorOptimizedParams const &base): + Conv2dFpropFilterIteratorOptimizedParams(base) { } + + CUTLASS_HOST_DEVICE + Params( + Conv2dProblemSize const &problem_size, + Layout const &layout + ): + Conv2dFpropFilterIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ) { + + } + }; + +private: + + Conv2dFpropFilterIteratorOptimizedParams const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + uint32_t predicates_[kAccessesPerVector]; + int filter_rs_; + int filter_c_; + int channels_per_group_; + + // + // Assertions + // + + // We map predicates into bits packed in this uint32_t container + static_assert(ThreadMap::Iterations::kStrided < sizeof(predicates_) * 8, + "Currently, the number of loads per iteration is limited by the size of the predicates container."); + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorOptimized( + Conv2dFpropFilterIteratorOptimizedParams const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() 
+ ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_{0}, + filter_rs_(0), + filter_c_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); + Index column = threadblock_offset.column() + thread_coord.strided(); + channels_per_group_ = problem_size_.C / problem_size_.groups; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + predicates_[v_idx] |= (pred << s); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); + } + + pointer_ += ( + params_.layout({filter_c_, column}) + ) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + LongIndex next = params_.inc_next_rs; + + // moves to the next tile + ++filter_rs_; + if (filter_rs_ == params_.RS) { + + filter_rs_ = 0; + next = params_.inc_next_c; + filter_c_ += params_.filter_c_delta; + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); + } + + pointer_ += next; + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(int v, bool clear = true) { + predicates_[v] = clear ? 0u : predicates_[v]; + } + + /// Returns true if the current coordinate is within the filter tensor W + CUTLASS_HOST_DEVICE + bool valid() { + return (predicates_[iteration_vector_] & (1u << iteration_strided_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + return reinterpret_cast(pointer_) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + + // Move to the next K coordinate within the tile + pointer_ += params_.inc_next_k; + + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
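Unlike the analytic iterators earlier in this diff, `advance()` and `operator++()` here never re-evaluate the layout functor; they add byte deltas (`inc_next_rs`, `inc_next_c`, `inc_next_k`) that the host computed once from the tensor strides. A simplified standalone sketch of why a single pointer add reproduces the direct offset computation for a packed NHWC-style filter tensor (hypothetical strides and element type; the real increments also fold in the strided-iteration offsets):

```cpp
// Sketch of precomputed-increment addressing: instead of evaluating
// layout(k, r, s, c) per access, advance a byte pointer by host-computed deltas.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  int const K = 8, R = 3, S = 3, C = 16;
  using Element = std::uint16_t;  // assumed half-precision-sized element

  // Packed NHWC-style strides in elements: c fastest, then s, r, k.
  std::int64_t stride_s = C;
  std::int64_t stride_r = std::int64_t(S) * C;

  std::vector<Element> filter(std::size_t(K) * R * S * C, Element(0));
  char const *base = reinterpret_cast<char const *>(filter.data());

  // Host-side precomputation: byte delta for moving one (r, s) position forward.
  std::int64_t inc_next_rs = stride_s * std::int64_t(sizeof(Element));

  // Device-side style loop: one pointer add per step, no layout evaluation.
  char const *ptr = base;  // starts at (r = 0, s = 0)
  for (int rs = 0; rs < 4; ++rs) {
    std::int64_t direct = (std::int64_t(rs % S) * stride_s + std::int64_t(rs / S) * stride_r) *
                          std::int64_t(sizeof(Element));
    std::printf("rs=%d incremental=%lld direct=%lld\n", rs,
                static_cast<long long>(ptr - base), static_cast<long long>(direct));
    ptr += inc_next_rs;  // analogue of pointer_ += params_.inc_next_rs
  }
  return 0;
}
```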
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if ((problem_size.C / problem_size.groups) % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (problem_size.K % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (problem_size.K % 64) { + return Status::kErrorInvalidProblem; + } + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_params.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_params.h new file mode 100644 index 00000000..30c21da6 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_params.h @@ -0,0 +1,893 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Extracts the host-params objects into non-template code. 
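Concretely, these host-side params front-load two kinds of work: byte-granular pointer increments derived from the tensor strides, and `FastDivmod` objects whose magic multiplier and shift are computed once on the CPU so the kernel never issues hardware integer division. A rough standalone illustration of what the `pq_divmod` / `q_divmod` pair is used for, splitting a linear output-pixel index into `(n, p, q)`, written here with ordinary division in place of the multiply-shift trick:

```cpp
// Reference use of the pq_divmod / q_divmod pair: recover (n, p, q) from a
// linear index over N*P*Q output positions. FastDivmod performs the same
// splits with a host-precomputed multiplier and shift.
#include <cstdio>

int main() {
  int const N = 2, P = 4, Q = 5;  // hypothetical output extent
  int const PQ = P * Q;

  for (int npq = 0; npq < N * PQ; npq += 9) {
    int n = npq / PQ;        // pq_divmod: quotient is the batch index
    int residual = npq % PQ; // remainder indexes within one output image
    int p = residual / Q;    // q_divmod: quotient is the output row
    int q = residual % Q;    // remainder is the output column
    std::printf("npq %2d -> (n=%d, p=%d, q=%d)\n", npq, n, p, q);
  }
  return 0;
}
```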
+*/ + +#pragma once + +#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED +#include +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Params structure used for all Conv2d analytic tile iterators +template< typename Layout_ = layout::TensorNHWC > +struct Conv2dAnalyticParams { + + using Layout = Layout_; + + Layout layout; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dAnalyticParams() { } + + CUTLASS_HOST_DEVICE + Conv2dAnalyticParams( + Conv2dProblemSize const &, // unused; placeholder to match other Params interfaces. + Layout const &layout + ): layout(layout) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Params structure used for all Conv2d analytic tile iterators +template< typename Layout_ = layout::TensorNHWC > +struct Conv2dFewChannelsParams { + + using Layout = Layout_; + + + int32_t stride_w; + int32_t stride_h; + int32_t stride_n; + + FastDivmod divmod_P; + FastDivmod divmod_Q; + FastDivmod divmod_S; + FastDivmod divmod_C; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dFewChannelsParams() { } + + CUTLASS_HOST_DEVICE + Conv2dFewChannelsParams( + Conv2dProblemSize const &problem_size, // unused; placeholder to match other Params interfaces. 
+ Layout const &layout + ): + stride_w(int32_t(layout.stride()[0])), + stride_h(int32_t(layout.stride()[1])), + stride_n(int32_t(layout.stride()[2])), + divmod_P(problem_size.P), + divmod_Q(problem_size.Q), + divmod_S(problem_size.S), + divmod_C(problem_size.C) + { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams +struct Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + int tiled_rows_per_filter; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams() { } + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape + ): layout(layout) { + + int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row()); + + tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED + +CUTLASS_HOST_DEVICE +void TraceIteratorParams( + char const *conv_operator, + char const *operand, + int element_size_bits, + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta +) { + +#if !defined(__CUDA_ARCH__) + + char const *fname = "conv_iterator_params.csv"; + + std::ifstream test(fname); + bool file_exists = test.is_open(); + + if (file_exists) { + test.close(); + } + + std::ofstream trace("conv_iterator_params.csv", std::ofstream::app); + + if (!file_exists) { + trace + << "Operator,Operand,ElementSize,CtaRows,CtaColumns,ThreadCount,AccessSize," + << "IterationsContiguous,IterationsStrided,DeltaContiguous,DeltaStrided\n"; + } + + trace << conv_operator << "," << operand << "," << element_size_bits << "," + << threadblock_shape.row() << "," << threadblock_shape.column() + << "," << thread_count << "," << access_size + << "," << threadmap_iterations.contiguous() << "," << threadmap_iterations.strided() + << "," << threadmap_delta.contiguous() << "," << threadmap_delta.strided() << "\n"; +#endif +} + +#define TRACE_CONV_INITIALIZERS(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta) \ + TraceIteratorParams(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta); + +#else + +#define TRACE_CONV_INITIALIZERS(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta) {} + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized +template< typename Layout_ = layout::TensorNHWC > +struct Conv2dFpropActivationIteratorOptimizedParams; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized +template<> +struct Conv2dFpropActivationIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + int64_t inc_next[3]; // {next S, next R, next C} + int filter_c_delta; // number of 
logical elements to add to filter_c_ + int PQ; // product of P*Q + + FastDivmod pq_divmod; + FastDivmod q_divmod; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), + PQ(problem_size.P * problem_size.Q), + pq_divmod(PQ), + q_divmod(problem_size.Q) { + + TRACE_CONV_INITIALIZERS("conv2d_fprop", "activation", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + int conv_sign = (problem_size.mode == Mode::kConvolution ? -1 : 1); + + // next S + inc_next[0] = conv_sign * ( + int64_t(layout.stride()[0]) * problem_size.dilation_w + ) * element_size_bits / 8; + + // next R + inc_next[1] = conv_sign * ( + int64_t(layout.stride()[1]) * problem_size.dilation_h + - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next C + inc_next[2] = ( + threadblock_shape.column() * problem_size.split_k_slices + - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h + - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // logical offset added to internal channel counter - units are elements, not bytes + filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; + } + +#if ENABLE_CONV2D_PARAMS_PRINT + /// Prints internal state. 
+ CUTLASS_HOST_DEVICE + void print() { + auto stride = layout.stride(); + printf( + "Conv2dFpropActivationIteratorOptimizedParams:\n" + " layout(w: %d, h: %d, n: %d)\n" + " inc_next[%ld, %ld, %ld]\n" + " filter_c_delta(%d) - PQ(%d)\n" + " pq_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n" + " q_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n", + stride[0], stride[1], stride[2], + inc_next[0], inc_next[1], inc_next[2], + filter_c_delta, + PQ, + pq_divmod.divisor, + pq_divmod.multiplier, + pq_divmod.shift_right, + q_divmod.divisor, + q_divmod.multiplier, + q_divmod.shift_right + ); + } +#endif +}; + +/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized +template +struct Conv2dFpropActivationIteratorOptimizedParams> { + static int const kInterleaved = Interleaved_; + + using Layout = layout::TensorNCxHWx; + + Layout layout; + + int64_t inc_next[3]; // {next S, next R, next C} + int filter_c_delta; // number of logical elements to add to filter_c_ + int PQ; // product of P*Q + + FastDivmod pq_divmod; + FastDivmod q_divmod; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), PQ(problem_size.P * problem_size.Q), pq_divmod(PQ), q_divmod(problem_size.Q) { + + TRACE_CONV_INITIALIZERS("conv2d_fprop", "activation", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + int conv_sign = (problem_size.mode == Mode::kConvolution ? 
-1 : 1); + + // next S + inc_next[0] = conv_sign * (kInterleaved * problem_size.dilation_w) * element_size_bits / 8; + + // next R + inc_next[1] = conv_sign * ( + int64_t(layout.stride()[0]) * problem_size.dilation_h + - (problem_size.S - 1) * kInterleaved * problem_size.dilation_w + ) * element_size_bits / 8; + + // next C + inc_next[2] = ( + threadblock_shape.column() * problem_size.split_k_slices / kInterleaved * int64_t(layout.stride()[1]) + - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[0] * problem_size.dilation_h + - conv_sign * int64_t(problem_size.S - 1) * kInterleaved * problem_size.dilation_w + ) * element_size_bits / 8; + + // logical offset added to internal channel counter - units are elements, not bytes + filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Layout_ = layout::TensorNHWC > +struct Conv2dFpropFilterIteratorOptimizedParams; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Conv2dFpropFilterIteratorOptimizedParams +{ + + using Layout = layout::TensorNHWC; + + Layout layout; + int RS; + int filter_c_delta; + + int64_t inc_next_k; // offset in units of bytes to next K position + int64_t inc_next_rs; // offset in units of bytes to next RS position + int64_t inc_next_c; // offset in units of bytes to next C position + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv2dFpropFilterIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dFpropFilterIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout) { + + TRACE_CONV_INITIALIZERS("conv2d_fprop", "filter", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + RS = problem_size.R * problem_size.S; + + inc_next_k = (int64_t(layout.stride()[2]) * threadmap_delta.strided() * element_size_bits) / 8; + + inc_next_rs = + ( int64_t(layout.stride()[0]) + - int64_t(layout.stride()[2]) * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() + ) * element_size_bits / 8; + + inc_next_c = + ( + threadblock_shape.row() * problem_size.split_k_slices + - int64_t(RS - 1) * layout.stride()[0] + - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] + ) * element_size_bits / 8; + + filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; + } + +#if ENABLE_CONV2D_PARAMS_PRINT + /// Prints internal state. 
+ CUTLASS_HOST_DEVICE + void print() { + auto stride = layout.stride(); + printf( + "Conv2dFpropFilterIteratorOptimizedParams:\n" + " layout[%d, %d, %d]\n" + " RS(%d), filter_c_delta(%d), inc_next(k: %ld, rs: %ld, c: %ld)\n", + stride[0], stride[1], stride[2], + RS, + filter_c_delta, + inc_next_k, inc_next_rs, inc_next_c + ); + } +#endif +}; + +template +struct Conv2dFpropFilterIteratorOptimizedParams> +{ + static int const kInterleaved = Interleaved_; + using Layout = layout::TensorCxRSKx; + + Layout layout; + int RS; + int filter_c_delta; + + int64_t inc_next_k; // offset in units of bytes to next K position + int64_t inc_next_rs; // offset in units of bytes to next RS position + int64_t inc_next_c; // offset in units of bytes to next C position + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv2dFpropFilterIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dFpropFilterIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout) { + + TRACE_CONV_INITIALIZERS("conv2d_fprop", "filter", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + RS = problem_size.R * problem_size.S; + + inc_next_k = (kInterleaved * threadmap_delta.strided() * element_size_bits) / 8; + + inc_next_rs = + ( int64_t(layout.stride()[0]) + - kInterleaved * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() + ) * element_size_bits / 8; + + inc_next_c = + ( + threadblock_shape.row() * problem_size.split_k_slices / kInterleaved * int64_t(layout.stride()[2]) + - int64_t(RS - 1) * layout.stride()[0] + - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * kInterleaved + ) * element_size_bits / 8; + + filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Dgrad Optimized Dy params (layout::TensorNHWC) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Parameters object for Conv2d DGRAD OutputGradient (dy) iterator +struct Conv2dDgradOutputGradientIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + int64_t inc_next[3]; // {next S, next R, next K} + + int filter_k_delta; // number of logical elements to add to filter_k_ + + int HW; // product of H*W + + FastDivmod hw_divmod; + FastDivmod w_divmod; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), + HW(problem_size.H *problem_size.W), + hw_divmod(HW), + w_divmod(problem_size.W) { + + TRACE_CONV_INITIALIZERS("conv2d_dgrad", "output_gradient", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + int conv_sign = (problem_size.mode == Mode::kConvolution ? 
1 : -1); + + // next S + inc_next[0] = conv_sign * ( + (int64_t)layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next R + inc_next[1] = conv_sign * ( + (int64_t)layout.stride()[1] * problem_size.dilation_h + - (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next K + inc_next[2] = ( + threadblock_shape.column() * problem_size.split_k_slices + - conv_sign * (problem_size.R - 1) * (int64_t)layout.stride()[1] * problem_size.dilation_h + - conv_sign * (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // logical offset added to internal channel counter - units are elements, not bytes + filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Strided Dgrad Optimized Dy params (layout::TensorNHWC) +///////////////////////////////////////////////////////////////////////////////////////////////// +struct Conv2dStridedDgradOutputGradientIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + int64_t inc_next[3]; // {next S, next R, next K} + + int filter_k_delta; // number of logical elements to add to filter_k_ + + int tiled_rows_per_filter; + + int conv_sign; + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dStridedDgradOutputGradientIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dStridedDgradOutputGradientIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape + ): layout(layout) { + + int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row()); + + tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row(); + + conv_sign = (problem_size.mode == Mode::kConvolution ? 
1 : -1); + + // next S + inc_next[0] = conv_sign * ( + (int64_t)layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next R + inc_next[1] = conv_sign * ( + (int64_t)layout.stride()[1] * problem_size.dilation_h + ) * element_size_bits / 8; + + // next K + inc_next[2] = ( + threadblock_shape.column() * problem_size.split_k_slices + ) * element_size_bits / 8; + + // logical offset added to internal channel counter - units are elements, not bytes + filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////// +// Dgrad Optimized w params (layout::TensorNHWC) +///////////////////////////////////////////////////////////////////////////////////////////////// +struct Conv2dDgradFilterIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + int RS; + int filter_k_delta; + + int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile + int64_t inc_next_rs; // offset in units of bytes to next RS position + int64_t inc_next_k; // offset in units of bytes to next K position in subsequent tile + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv2dDgradFilterIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dDgradFilterIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), RS(problem_size.R * problem_size.S) { + + TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + inc_next_strided = ((int64_t)layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; + + inc_next_rs = + ( (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] + ) * element_size_bits / 8; + + inc_next_k = + ( + threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] + - (problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] + ) * element_size_bits / 8; + + filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////// +// StridedDgrad Optimized w params (layout::TensorNHWC) +///////////////////////////////////////////////////////////////////////////////////////////////// +struct Conv2dStridedDgradFilterIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + int RS; + int filter_k_delta; + + int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile + int64_t inc_next[3]; // {next S, next R, next K} + int64_t reset_bytes; // offset in units of bytes to move back the pointer + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv2dStridedDgradFilterIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dStridedDgradFilterIteratorOptimizedParams( + 
Conv2dProblemSize const &problem_size,
+    Layout const &layout,
+    int element_size_bits,                        ///< size of each element in bits
+    MatrixCoord threadblock_shape,
+    int thread_count,
+    int access_size,
+    layout::PitchLinearCoord threadmap_iterations,
+    layout::PitchLinearCoord threadmap_delta
+  ):
+    layout(layout), RS(problem_size.R * problem_size.S) {
+
+    TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter",
+      element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
+
+    inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8;
+
+    // next S
+    inc_next[0] =
+      ( (int64_t)layout.stride()[0] * problem_size.stride_w
+        //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
+      ) * element_size_bits / 8;
+
+    // next R
+    inc_next[1] =
+      ( (int64_t)layout.stride()[1] * problem_size.stride_h
+        //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
+      ) * element_size_bits / 8;
+
+    // next K
+    inc_next[2] =
+      (
+        threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2]
+        //- (problem_size.R * problem_size.S - 1) * layout.stride()[0]
+        //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
+      ) * element_size_bits / 8;
+
+    // offset in units of bytes to move the pointer in backward direction
+    reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2]
+      * element_size_bits / 8;
+
+    filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices;
+  }
+};
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Parameters object for Conv2d WGRAD Output Gradient (dy) iterator
+struct Conv2dWgradOutputGradientIteratorOptimizedParams {
+
+  using Layout = layout::TensorNHWC;
+
+  Layout layout;
+
+  int NPQ;                        // precomputed product of N*P*Q for clearing predicates
+
+  FastDivmod pq_divmod;
+  FastDivmod q_divmod;
+
+  int64_t offset_next_strided;    // offset in units of bytes to next npq coordinate within tile
+  int64_t offset_next_contiguous; // offset in units of bytes to next k coordinate within tile
+  int64_t inc_next_npq;           // offset in units of bytes to next npq position in subsequent tile
+
+  //
+  // Methods
+  //
+
+  CUTLASS_HOST_DEVICE
+  Conv2dWgradOutputGradientIteratorOptimizedParams() { }
+
+  CUTLASS_HOST_DEVICE
+  Conv2dWgradOutputGradientIteratorOptimizedParams(
+    Conv2dProblemSize const &problem_size,
+    Layout const &layout,
+    int element_size_bits,                        ///< size of each element in bits
+    MatrixCoord threadblock_shape,
+    int thread_count,
+    int access_size,
+    layout::PitchLinearCoord threadmap_iterations,
+    layout::PitchLinearCoord threadmap_delta
+  ):
+    layout(layout),
+    NPQ(problem_size.N * problem_size.P * problem_size.Q),
+    pq_divmod(problem_size.P * problem_size.Q),
+    q_divmod(problem_size.Q) {
+
+    TRACE_CONV_INITIALIZERS("conv2d_wgrad", "output_gradient",
+      element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
+
+    // Incremental offsets in units of bytes: (number of elements) * sizeof_bits<Element>::value / 8
+    offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0])
+      * element_size_bits / 8;
+
+    offset_next_contiguous = (threadmap_delta.contiguous())
+      * element_size_bits / 8;
+
+    inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0])
+      * element_size_bits / 8;
+  }
+};
+
+struct 
Conv2dWgradActivationIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + FastDivmod sc_divmod; + FastDivmod pq_divmod; + FastDivmod q_divmod; + FastDivmod c_divmod; + FastDivmod s_divmod; + int small_channel_conv_s_offset; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv2dWgradActivationIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dWgradActivationIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout + ): + layout(layout), + sc_divmod(problem_size.S * problem_size.C), + pq_divmod(problem_size.P * problem_size.Q), + q_divmod(problem_size.Q), + c_divmod(problem_size.C), + s_divmod(problem_size.S * problem_size.dilation_w), + small_channel_conv_s_offset((problem_size.S - 1) * problem_size.dilation_w - problem_size.pad_w) { + } + + CUTLASS_HOST_DEVICE + Conv2dWgradActivationIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + Conv2dWgradActivationIteratorOptimizedParams( + problem_size, + layout + ) { + + TRACE_CONV_INITIALIZERS("conv2d_wgrad", "activation", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + } +}; + +struct PredicatedScaleBiasVectorAccessIteratorParams { + public: + /// Default ctor + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIteratorParams() { } + + // Default ctor + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIteratorParams( + Conv2dProblemSize const &problem_size, + layout::PitchLinear const &layout) {} + + // Default ctor + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIteratorParams( + Conv2dProblemSize const &problem_size, + layout::RowMajor const &layout) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_tile_iterator.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_tile_iterator.h new file mode 100644 index 00000000..150ff689 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_tile_iterator.h @@ -0,0 +1,336 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template wraps the tile access iterator concept to load whole tiles from tensors in + memory used for implicit GEMM convolution. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TileIterator { +public: + using TileAccessIterator = TileAccessIterator_; + + using Shape = typename TileAccessIterator::Shape; + using Element = typename TileAccessIterator::Element; + using Layout = typename TileAccessIterator::Layout; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = typename TileAccessIterator::ThreadMap; + using AccessType = typename TileAccessIterator::AccessType; + using TensorRef = typename TileAccessIterator::TensorRef; + using Index = typename TileAccessIterator::Index; + using LongIndex = typename TileAccessIterator::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport; + using Params = typename TileAccessIterator::Params; + static int const kConvDim = TileAccessIterator::kConvDim; + using ConvProblemSize = typename TileAccessIterator::ConvProblemSize; + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + +private: + + /// Internal state + TileAccessIterator tile_access_iterator_; + +public: + + /// Constructor + CUTLASS_HOST_DEVICE + TileIterator( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + tile_access_iterator_(params, problem_size, ptr, thread_idx, threadblock_offset) { } + + CUTLASS_HOST_DEVICE + static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) { + return TileAccessIterator::getParams(problem_size, layout); + } + + /// Overrides the internal iteration 
index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + tile_access_iterator_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + tile_access_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + TileIterator &operator++() { + tile_access_iterator_.advance(); + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + TileIterator operator++(int) { + TileIterator self(*this); + operator++(); + return self; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + frag.clear(); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[idx], + tile_access_iterator_.get() + pointer_offset, + tile_access_iterator_.valid() + ); + + ++tile_access_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + tile_access_iterator_.set_iteration_index(0); + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void advance() { + tile_access_iterator_.advance(); + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + + // dispatch to iterator implementation + return TileAccessIterator::can_implement(problem_size); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Strided Dgrad Tile Iterator +template +class TileIteratorStridedDgrad { +public: + using TileAccessIterator = TileAccessIterator_; + + using Shape = typename TileAccessIterator::Shape; + using Element = typename TileAccessIterator::Element; + using Layout = typename TileAccessIterator::Layout; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = typename TileAccessIterator::ThreadMap; + using AccessType = typename TileAccessIterator::AccessType; + using TensorRef = typename TileAccessIterator::TensorRef; + using Index = typename TileAccessIterator::Index; + using LongIndex = typename TileAccessIterator::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport; + using Params = typename TileAccessIterator::Params; + static int const kConvDim = TileAccessIterator::kConvDim; + using ConvProblemSize = typename TileAccessIterator::ConvProblemSize; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + +private: + + /// Internal state + TileAccessIterator tile_access_iterator_; + +public: + + /// Constructor (output gradient (Dy) OperandA ctor) + CUTLASS_HOST_DEVICE + TileIteratorStridedDgrad( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, + int 
start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + tile_access_iterator_( + params, + problem_size, + ptr, + thread_idx, + stride_h_divmod, stride_w_divmod, + start_r, start_s, + threadblock_offset) { } + + /// Constructor (filter (w) OperandB ctor) + CUTLASS_HOST_DEVICE + TileIteratorStridedDgrad( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + tile_access_iterator_(params, + problem_size, + ptr, + thread_idx, + start_r, start_s, + threadblock_offset) { } + + CUTLASS_HOST_DEVICE + static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) { + return TileAccessIterator::getParams(problem_size, layout); + } + + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + tile_access_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + TileIteratorStridedDgrad &operator++() { + tile_access_iterator_.advance(); + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + TileIteratorStridedDgrad operator++(int) { + TileIteratorStridedDgrad self(*this); + operator++(); + return self; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + frag.clear(); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[c + s * ThreadMap::Iterations::kContiguous], + tile_access_iterator_.get() + pointer_offset, + tile_access_iterator_.valid() + ); + + ++tile_access_iterator_; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + tile_access_iterator_.set_iteration_index(0); + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void advance() { + tile_access_iterator_.advance(); + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + + // dispatch to iterator implementation + return TileAccessIterator::can_implement(problem_size); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h new file mode 100644 index 00000000..649201a2 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h @@ -0,0 +1,285 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
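+
+    In this analytic Wgrad specialization the contiguous GEMM dimension walks the flattened
+    (r, s, c) filter footprint and the strided GEMM dimension walks the flattened (n, p, q)
+    output positions; the (n, p, q) decomposition is done with plain division and modulo in at().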
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dWgradActivationTileAccessIteratorAnalytic { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + using Params = Conv2dAnalyticParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + // Filter postion (r,s,c) in contiguous dimension stays constant for each gemm_iteration_k + int filter_r_[ThreadMap::Iterations::kContiguous]; + int filter_s_[ThreadMap::Iterations::kContiguous]; + int filter_c_[ThreadMap::Iterations::kContiguous]; + + int offset_npq_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dWgradActivationTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)) + { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + // initialize r,s,c filter position for every contiguous iteration + CUTLASS_PRAGMA_UNROLL + for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int rsc_offset = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + + filter_r_[c] = rsc_offset / (problem_size_.S * problem_size_.C); + int residual = rsc_offset % (problem_size_.S * problem_size_.C); + + filter_s_[c] = residual / problem_size_.C; + filter_c_[c] = residual % problem_size_.C; + } + + // initialize n, p, q offset for every strided iteration + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + 
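+      // offset_npq_ is the flattened n*(P*Q) + p*Q + q index along GEMM-K;
+      // it is decomposed back into (n, p, q) inside at() for every access.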
+ offset_npq_[s] = threadblock_offset.row() + thread_coord.strided() + + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + // moves to the next GEMM-K offset (offset_npq_) in GEMM-B by a CTA-K tile + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_npq_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the activation tensor x that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + int r, s, c; + + if (kAccessesPerVector == 1) { + /// One 128b aligned access fetching more than one element + c = filter_c_[iteration_contiguous_]; + r = filter_r_[iteration_contiguous_]; + s = filter_s_[iteration_contiguous_]; + } + else { + /// Multiple access to support non-128b alignment in contiguous dimension + c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) % problem_size_.C; + int wrap_c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) / problem_size_.C; + s = (filter_s_[iteration_contiguous_] + wrap_c) % problem_size_.S; + int wrap_s = (filter_s_[iteration_contiguous_] + wrap_c) / problem_size_.S; + r = filter_r_[iteration_contiguous_] + wrap_s; + } + + if (problem_size_.mode == Mode::kConvolution) { + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q); + int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q); + + int p = residual / problem_size_.Q; + int q = residual % problem_size_.Q; + + int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; + int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; + + return TensorCoord(n, h, w, c); + } + + /// Returns true if the current coordinate is within the activation tensor x + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dWgradActivationTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + 
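+    // all accesses of the current tile have been visited; advance() is what moves
+    // offset_npq_ forward to the next GEMM-K tile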
+ return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h new file mode 100644 index 00000000..8b772eb2 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h @@ -0,0 +1,321 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
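+
+    The optimized Wgrad specialization below precomputes the (r, s, c) decomposition with
+    FastDivmod objects and folds padding and dilation into precomputed_filter_r_/s_, so at()
+    avoids the per-access integer divisions of the analytic iterator.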
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dWgradActivationTileAccessIteratorOptimized { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + using Params = Conv2dWgradActivationIteratorOptimizedParams; + +private: + + Conv2dWgradActivationIteratorOptimizedParams const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + // Precomputed effective filter postion (r,s) in contiguous dimension stays constant for each gemm_iteration_k + // required for npq -> nhw translation + int precomputed_filter_r_[ThreadMap::Iterations::kContiguous]; + int precomputed_filter_s_[ThreadMap::Iterations::kContiguous]; + + // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k + int filter_c_[ThreadMap::Iterations::kContiguous]; + + int offset_npq_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dWgradActivationTileAccessIteratorOptimized( + Conv2dWgradActivationIteratorOptimizedParams const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)) + { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + // initialize r,s,c filter position for every contiguous iteration + CUTLASS_PRAGMA_UNROLL + for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int rsc_offset = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // filter_r_[c] = rsc_offset / 
(problem_size_.S * problem_size_.C); + // int residual = rsc_offset % (problem_size_.S * problem_size_.C); + // + // filter_s_[c] = residual / problem_size_.C; + // filter_c_[c] = residual % problem_size_.C; + + int residual; + params_.sc_divmod(precomputed_filter_r_[c], residual, rsc_offset); + params_.c_divmod(precomputed_filter_s_[c], filter_c_[c], residual); + + int r = precomputed_filter_r_[c]; + int s = precomputed_filter_s_[c]; + + if (problem_size_.mode == Mode::kConvolution) { + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + precomputed_filter_r_[c] = -problem_size_.pad_h + r * problem_size_.dilation_h; + precomputed_filter_s_[c] = -problem_size_.pad_w + s * problem_size_.dilation_w; + } + + // initialize n, p, q offset for every strided iteration + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + offset_npq_[s] = threadblock_offset.row() + thread_coord.strided() + + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + // moves to the next GEMM-K offset (offset_npq_) in GEMM-B by a CTA-K tile + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_npq_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the activation tensor x that is currently pointed to + /// by the iterator. 
+  CUTLASS_HOST_DEVICE
+  TensorCoord at() const {
+    int r = precomputed_filter_r_[iteration_contiguous_];
+    int s = precomputed_filter_s_[iteration_contiguous_];
+    int c = filter_c_[iteration_contiguous_];
+
+    if (kAccessesPerVector > 1) {
+      // This code section is only to support non-128b alignment
+      // Multiple accesses to support non-128b alignment in contiguous dimension
+      int wrap_c;
+      params_.c_divmod(wrap_c, c, c + iteration_vector_ * AccessType::kElements);
+
+      if (problem_size_.mode == Mode::kConvolution) {
+        s -= (problem_size_.dilation_w * wrap_c);
+
+        int wrap_s;
+        params_.s_divmod(wrap_s, s, params_.small_channel_conv_s_offset - s);
+        s = params_.small_channel_conv_s_offset - s;
+
+        r -= (problem_size_.dilation_h * wrap_s);
+
+      } else {
+        s += (problem_size_.dilation_w * wrap_c);
+
+        int wrap_s;
+        params_.s_divmod(wrap_s, s, s + problem_size_.pad_w);
+        s -= problem_size_.pad_w;
+
+        r += (problem_size_.dilation_h * wrap_s);
+      }
+    }
+
+    // The subsequent fast_divmod() operations are equivalent to the following logical computation:
+    //
+    //
+    // int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q);
+    // int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q);
+    //
+    // int p = residual / problem_size_.Q;
+    // int q = residual % problem_size_.Q;
+
+    int residual, n, p, q;
+
+    params_.pq_divmod(n, residual, offset_npq_[iteration_strided_]);
+    params_.q_divmod(p, q, residual);
+
+    int h = p * problem_size_.stride_h + r;
+    int w = q * problem_size_.stride_w + s;
+
+    return TensorCoord(n, h, w, c);
+  }
+
+  /// Returns true if the current coordinate is within the activation tensor x
+  CUTLASS_HOST_DEVICE
+  bool valid() const {
+    TensorCoord coord = at();
+
+    return coord.n() < problem_size_.N &&
+      coord.h() >= 0 && coord.h() < problem_size_.H &&
+      coord.w() >= 0 && coord.w() < problem_size_.W;
+  }
+
+  /// Returns a pointer to the vector starting at the current coordinate
+  CUTLASS_HOST_DEVICE
+  AccessType const *get() const {
+
+    TensorCoord coord = at();
+    LongIndex offset = params_.layout(coord);
+
+    return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
+  }
+
+  /// Increments to the next memory access
+  CUTLASS_HOST_DEVICE
+  Conv2dWgradActivationTileAccessIteratorOptimized &operator++() {
+    ++iteration_vector_;
+    if (iteration_vector_ < kAccessesPerVector) {
+      return *this;
+    }
+    iteration_vector_ = 0;
+
+    ++iteration_contiguous_;
+    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
+      return *this;
+    }
+    iteration_contiguous_ = 0;
+    ++iteration_strided_;
+    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
+      return *this;
+    }
+    iteration_strided_ = 0;
+
+    return *this;
+  }
+
+  /// Determines whether the Implicit GEMM can execute the given problem.
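+  /// For this iterator the contiguous dimension maps to input channels, so the alignment
+  /// constraint is checked against problem_size.C.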
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h new file mode 100644 index 00000000..c8e2f519 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -0,0 +1,260 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
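+
+    This analytic specialization loads the output gradient (Dy) as GEMM operand A: the
+    contiguous dimension walks the K channels and the strided dimension walks the flattened
+    (n, p, q) output positions, decomposed with plain division and modulo in at().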
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dWgradOutputGradientTileAccessIteratorAnalytic { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + using Params = Conv2dAnalyticParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_k_[ThreadMap::Iterations::kContiguous]; + + int offset_npq_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dWgradOutputGradientTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + // initialize filter_k for every contiguous iteration + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + filter_k_[c] = threadblock_offset.row() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + } + + // initialize n, p, q offset for every strided iteration + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_npq_[s] = threadblock_offset.column() + thread_coord.strided() + + s * ThreadMap::Delta::kStrided; + + } + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int 
residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_npq_[s] += Shape::kColumn * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the output gradient tensor Dy that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int npq = offset_npq_[iteration_strided_]; + + int n = npq / (problem_size_.P * problem_size_.Q); + int residual = npq % (problem_size_.P * problem_size_.Q); + + int p = residual / problem_size_.Q; + int q = residual % problem_size_.Q; + + int k = filter_k_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; + + return TensorCoord(n, p, q, k); + } + + + /// Returns true if the current coordinate is within the output gradient tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.h() < problem_size_.P && + coord.w() < problem_size_.Q && + coord.c() < problem_size_.K; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dWgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
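+  /// Here the contiguous dimension maps to K, so the alignment constraint is checked against
+  /// problem_size.K.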
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h new file mode 100644 index 00000000..350b652b --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -0,0 +1,310 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
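+
+    The optimized specialization below replaces the per-access (n, p, q) divisions with
+    precomputed FastDivmod objects and keeps one packed predicate word per vector access,
+    cleared in advance() once offset_npq_ runs past N*P*Q.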
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dWgradOutputGradientTileAccessIteratorOptimized { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + using Params = Conv2dWgradOutputGradientIteratorOptimizedParams; + +private: + + Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + uint32_t predicates_[kAccessesPerVector]; + int filter_k_; + int offset_npq_; + +public: + + CUTLASS_HOST_DEVICE + Conv2dWgradOutputGradientTileAccessIteratorOptimized( + Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_{0}, + filter_k_(0), + offset_npq_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.row() + thread_coord.contiguous(); + offset_npq_ = threadblock_offset.column() + thread_coord.strided(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous; + int offset_npq = offset_npq_ + s * ThreadMap::Delta::kStrided; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + bool predicate = valid_(at_(offset_npq, filter_k + v * AccessType::kElements)); + + uint32_t pred = (predicate ? 
1u : 0); + + int pred_idx = c + s * ThreadMap::Iterations::kContiguous; + + predicates_[v] |= (pred << pred_idx); + } + } + } + + // Offset pointer to (iteration_strided_, iteration_contiguous_) = (0, 0) + pointer_ += ( + offset_npq_ * params.layout.stride()[0] + filter_k_ + ) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile + offset_npq_ += Shape::kColumn * problem_size_.split_k_slices; + + // Clear predicates if needed + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + if (offset_npq_ + s * ThreadMap::Delta::kStrided >= params_.NPQ) { + uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + predicates_[v] = (predicates_[v] & (~kClearMask)); + } + } + } + + pointer_ += params_.inc_next_npq; + } + +private: + /// Returns the coordinate in the output gradient tensor Dy that is pointed to + /// by offset_npq and k. 
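+  /// n, p and q are recovered from offset_npq with the precomputed FastDivmod objects
+  /// (pq_divmod and q_divmod).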
+ CUTLASS_HOST_DEVICE + TensorCoord at_(int offset_npq, int k) const { + + // The subsequent fast_divmod() operations are equivalent to the following logical computation: + // + // + // int npq = offset_npq; + // int n = npq / (problem_size_.P * problem_size_.Q); + // int residual = npq % (problem_size_.P * problem_size_.Q); + // + // int p = residual / problem_size_.Q; + // int q = residual % problem_size_.Q; + + int residual, n, p, q; + + params_.pq_divmod(n, residual, offset_npq); + params_.q_divmod(p, q, residual); + + return TensorCoord(n, p, q, k); + } + + /// Returns true if the coord is within the output gradient tensor Dy + CUTLASS_HOST_DEVICE + bool valid_(TensorCoord coord) const { + + return coord.n() < problem_size_.N && + coord.c() < problem_size_.K; + } + +public: + + /// Returns true if the current coordinate is within the output gradient tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + + LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; + return (predicates_[iteration_vector_] & (1u << pred_idx)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast( + pointer_ + + iteration_strided_ * params_.offset_next_strided + + iteration_contiguous_ * params_.offset_next_contiguous + ) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dWgradOutputGradientTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h new file mode 100644 index 00000000..42eca7c9 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h @@ -0,0 +1,268 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
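For orientation, the iterator defined below advances through filter positions with s moving fastest, then r, then t; only after a complete (T, R, S) sweep does it step the K offsets by a CTA tile (scaled by split_k_slices). A rough standalone sketch of that carry chain, with made-up extents and tile size:

#include <cassert>

int main() {
  const int T = 2, R = 3, S = 3;              // hypothetical filter extents
  const int tile_k = 64, split_k_slices = 1;  // hypothetical Shape::kRow and split-k factor

  int t = 0, r = 0, s = 0, offset_k = 0;
  for (int step = 0; step < T * R * S; ++step) {
    // one call to advance(): increment with carries s -> r -> t, then bump K
    if (++s < S) continue;
    s = 0;
    if (++r < R) continue;
    r = 0;
    if (++t < T) continue;
    t = 0;
    offset_k += tile_k * split_k_slices;      // full sweep finished
  }
  assert(t == 0 && r == 0 && s == 0 && offset_k == tile_k);
  return 0;
}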
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dDgradFilterTileAccessIteratorAnalytic { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or larger."); + + // + // Parameters structure + // + + struct Params { + + Layout layout; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params( + Conv3dProblemSize const &problem_size, + Layout const &layout + ): layout(layout) { + + } + }; + +private: + + Params const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + // For a fixed filter position (t,r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension + int filter_t_; + int filter_r_; + int filter_s_; + int offset_k_[ThreadMap::Iterations::kStrided]; + int offset_c_[ThreadMap::Iterations::kContiguous]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dDgradFilterTileAccessIteratorAnalytic( + Params const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_t_(0), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + } + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = + threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += 
pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + ++filter_t_; + if (filter_t_ < problem_size_.T) { + return; + } + filter_t_ = 0; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the filter tensor w that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int c = offset_c_[iteration_contiguous_]; + int k = offset_k_[iteration_strided_]; + + return TensorCoord(k, filter_t_, filter_r_, filter_s_, c); + } + + /// Returns true if the current coordinate is within the filter tensor w + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dDgradFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % (128/sizeof_bits::value)) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h new file mode 100644 index 00000000..5dcf8b53 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h @@ -0,0 +1,289 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" + +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity +> +class Conv3dDgradFilterTileAccessIteratorOptimized { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = StrideSupport_; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + + // + // Parameters structure + // + + struct Params : Conv3dDgradFilterIteratorOptimizedParams { + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Conv3dDgradFilterIteratorOptimizedParams const &base): + 
Conv3dDgradFilterIteratorOptimizedParams(base) { } + + CUTLASS_HOST_DEVICE + Params( + Conv3dProblemSize const &problem_size, + Layout const &layout + ): + Conv3dDgradFilterIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ) { } + + }; + +private: + + Conv3dDgradFilterIteratorOptimizedParams const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + uint32_t predicates_; + int filter_trs_; + int filter_k_; + + // + // Assertions + // + + // We map predicates into bits packed in this uint32_t container + static_assert(ThreadMap::Iterations::kStrided * + ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, + "Currently, the number of loads per iteration is limited by the size of the predicates container."); + +public: + + CUTLASS_HOST_DEVICE + Conv3dDgradFilterTileAccessIteratorOptimized( + Conv3dDgradFilterIteratorOptimizedParams const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_(0), + filter_trs_(0), + filter_k_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.row() + thread_coord.strided(); + Index column = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; + int filter_c = column + c * ThreadMap::Delta::kContiguous; + + uint32_t pred = ((filter_k < problem_size_.K && filter_c < problem_size_.C) ? 
1u : 0); + + int pred_idx = c + s * ThreadMap::Iterations::kContiguous; + + predicates_ |= (pred << pred_idx); + } + } + + pointer_ += ( + filter_k_ * params.layout.stride()[3] + column + ) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + LongIndex next = params_.inc_next_trs; + + // moves to the next tile + ++filter_trs_; + if (filter_trs_ == params_.TRS) { + + filter_trs_ = 0; + next = params_.inc_next_k; + filter_k_ += params_.filter_k_delta; + } + + // Clear predicates if needed + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { + uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); + + predicates_ = (predicates_ & (~kClearMask)); + } + } + + pointer_ += next; + } + + /// Returns true if the current coordinate is within the filter tensor W + CUTLASS_HOST_DEVICE + bool valid() { + LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; + return (predicates_ & (1u << pred_idx)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + return reinterpret_cast(pointer_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dDgradFilterTileAccessIteratorOptimized &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + + // Move to the next K coordinate within the tile + pointer_ += params_.inc_next_strided; + + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
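The can_implement() routine that follows enforces the vectorized-access rule used throughout these iterators: accesses are 128 bits wide, so the contiguous extent (C here) must be a multiple of 128 / sizeof_bits<Element>::value elements. A small worked example with illustrative element widths:

constexpr int kAccessBits = 128;

constexpr int elements_per_access(int bits_per_element) {
  return kAccessBits / bits_per_element;
}

constexpr bool alignment_ok(int C, int bits_per_element) {
  return (C % elements_per_access(bits_per_element)) == 0;
}

static_assert(elements_per_access(16) == 8, "fp16: C must be a multiple of 8");
static_assert(elements_per_access(32) == 4, "fp32: C must be a multiple of 4");
static_assert(alignment_ok(64, 16),  "C = 64 with fp16 elements is accepted");
static_assert(!alignment_ok(62, 16), "C = 62 with fp16 elements is rejected");

int main() { return 0; }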
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % (128/sizeof_bits::value)) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h new file mode 100644 index 00000000..22025788 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -0,0 +1,343 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided +> +class Conv3dDgradOutputGradientTileAccessIteratorAnalytic; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv3dDgradOutputGradientTileAccessIteratorAnalytic strided dgrad needs special handling using +// unscaled coordinations +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dDgradOutputGradientTileAccessIteratorAnalytic < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kStrided +> { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or greater."); + + // + // Simpligying assertions + // + + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + struct Params { + + Layout layout; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params( + ConvProblemSize const &problem_size, + Layout const &layout + ): layout(layout) { + + } + }; + +private: + + Params const ¶ms_; + ConvProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + int filter_k_; + int filter_t_; + int filter_r_; + int filter_s_; + + int offset_n_[ThreadMap::Iterations::kStrided]; + int offset_d_[ThreadMap::Iterations::kStrided]; + int offset_w_[ThreadMap::Iterations::kStrided]; + int offset_h_[ThreadMap::Iterations::kStrided]; + +private: + + /// Returns the coordinate in the output tensor Dy that is currently pointed to + /// by the iterator but DOES NOT scale by the convolution stride. This is needed + /// to compute predicates in the valid() method. The return value of the public at() + /// method is correctly scaled. 
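A hand-worked example of what unscaled_at_(), valid(), and at() below do for one spatial dimension; the numbers are invented purely for illustration:

#include <cassert>

int main() {
  // Hypothetical problem: output depth Z = 8, stride_d = 2, pad_d = 1, dilation_d = 1.
  const int Z = 8, stride_d = 2, pad_d = 1, dilation_d = 1;

  const int d = 5;   // activation-gradient depth index for this thread
  const int t = 0;   // current filter depth position

  // unscaled_at_(): output index before dividing by the stride.
  int z_unscaled = d + pad_d - t * dilation_d;              // = 6

  // valid(): this tap contributes only if the unscaled index lands on a stride
  // multiple and the scaled index is inside the output extent.
  bool contributes = (z_unscaled % stride_d == 0) &&
                     (z_unscaled / stride_d >= 0) &&
                     (z_unscaled / stride_d <  Z);

  // at(): the coordinate actually used for the load is the scaled one.
  int z = z_unscaled / stride_d;                            // = 3

  assert(contributes && z == 3);
  return 0;
}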
+ CUTLASS_HOST_DEVICE + TensorCoord unscaled_at_() const { + int n = offset_n_[iteration_strided_]; + int d = offset_d_[iteration_strided_]; + int h = offset_h_[iteration_strided_]; + int w = offset_w_[iteration_strided_]; + + int t = filter_t_; + int r = filter_r_; + int s = filter_s_; + + if (problem_size_.mode == Mode::kConvolution) { + t = (problem_size_.T - 1 - t); + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + int z = (d + problem_size_.pad_d - t * problem_size_.dilation_d); + int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h); + int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w); + + return TensorCoord(n, z, p, q, filter_k_); + } + +public: + + CUTLASS_HOST_DEVICE + Conv3dDgradOutputGradientTileAccessIteratorAnalytic( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_k_(0), + filter_t_(0), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + int offset_ndhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + offset_n_[s] = offset_ndhw / (problem_size_.D * problem_size_.H * problem_size_.W); + int residual = offset_ndhw % (problem_size_.D * problem_size_.H * problem_size_.W); + + offset_d_[s] = residual / (problem_size_.H * problem_size_.W); + residual = residual % (problem_size_.H * problem_size_.W); + + offset_h_[s] = residual / problem_size_.W; + offset_w_[s] = residual % problem_size_.W; + } + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // move to the next tile + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + ++filter_t_; + if (filter_t_ < problem_size_.T) { + return; + } + filter_t_ = 0; + + filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the output tensor Dy that is currently pointed to + /// by the iterator. 
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + + TensorCoord coord = unscaled_at_(); + + return TensorCoord( + coord.n(), + coord.d() / problem_size_.stride_d, + coord.h() / problem_size_.stride_h, + coord.w() / problem_size_.stride_w, + coord.c()); + } + + + /// Returns true if the current coordinate is within the output tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord unscaled_coord = unscaled_at_(); + TensorCoord coord = at(); + + return + !(unscaled_coord.d() % problem_size_.stride_d) && + !(unscaled_coord.h() % problem_size_.stride_h) && + !(unscaled_coord.w() % problem_size_.stride_w) && + coord.n() < problem_size_.N && + coord.d() >= 0 && coord.d() < problem_size_.Z && + coord.h() >= 0 && coord.h() < problem_size_.P && + coord.w() >= 0 && coord.w() < problem_size_.Q && + coord.c() < problem_size_.K; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % (128/sizeof_bits::value)) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h new file mode 100644 index 00000000..f5ada7f5 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -0,0 +1,487 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity +> +class Conv3dDgradOutputGradientTileAccessIteratorOptimized { +public: + + static_assert(StrideSupport_ == conv::StrideSupport::kUnity, + "Only unit-stride dgrad is supported at this time."); + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + using Coord3D = Coord<3>; + static int const kAccessesPerVector = 1; + using Mask = uint64_t; + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = 
Conv3dDgradOutputGradientIteratorOptimizedParams; + +private: + + Params const ¶ms_; + ConvProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + + + // One pointer per access + char const *pointer_[ThreadMap::Iterations::kStrided]; + + // current filter position (t, r, s) + int filter_t_; + int filter_r_; + int filter_s_; + int filter_k_; + + Index masks_[ThreadMap::Iterations::kStrided][3]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dDgradOutputGradientTileAccessIteratorOptimized( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ): + params_(params), + problem_size_(problem_size), + filter_k_(0), + filter_t_(0), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + int offset_n[ThreadMap::Iterations::kStrided]; + int offset_d[ThreadMap::Iterations::kStrided]; + int offset_h[ThreadMap::Iterations::kStrided]; + int offset_w[ThreadMap::Iterations::kStrided]; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + pointer_[s] = reinterpret_cast(ptr); + + int offset_ndhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // offset_n[s] = offset_ndhw / (problem_size_.D * problem_size_.H * problem_size_.W); + // int residual = offset_ndhw % (problem_size_.D * problem_size_.H * problem_size_.W); + // + // + // offset_d[s] = residual / (problem_size_.H * problem_size_.W); + // residual = residual % (problem_size_.H * problem_size_.W); + // + // offset_h[s] = residual / problem_size_.W; + // offset_w[s] = residual % problem_size_.W; + // + + int residual; + + // input: (ndhw offset) output: (n offset and resudial (dhw offset)) + params_.dhw_divmod(offset_n[s], residual, offset_ndhw); + // input: (dhw offset) output: (d offset and resudial (hw)) + params_.hw_divmod(offset_d[s], residual, residual); + // input: (hw offset) output: (h offset and resudial (w offset)) + params_.w_divmod(offset_h[s], offset_w[s], residual); + + TensorCoord coord = at_(offset_n[s], offset_d[s], offset_h[s], offset_w[s], 0, 0, 0); + + pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; + } + + clear_mask(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int t = 0; t < problem_size_.T; ++t) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int t_ = t; + if (problem_size_.mode == Mode::kConvolution) { + t_ = problem_size_.T - 1 - t; + } + + int z = offset_d[s_idx] + problem_size_.pad_d - t_ * problem_size_.dilation_d; + + bool pred = (offset_n[s_idx] < problem_size_.N && z >= 0 && z < problem_size_.Z); + masks_[s_idx][0] |= (pred << t); + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int r = 0; r < problem_size_.R; ++r) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int r_ = r; + if (problem_size_.mode == Mode::kConvolution) { + r_ = problem_size_.R - 1 - r; + } + + int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h; + + bool pred = (p >= 0 && p < problem_size_.P); + masks_[s_idx][1] |= (pred << r); + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int s = 0; s < problem_size_.S; 
++s) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int s_ = s; + if (problem_size_.mode == Mode::kConvolution) { + s_ = problem_size_.S - 1 - s; + } + + int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w; + + bool pred = (q >= 0 && q < problem_size_.Q); + masks_[s_idx][2] |= (pred << s); + } + } + + if (filter_k_ >= problem_size.K) { + clear_mask(); + } + + set_iteration_index(0); + + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); + } + +private: + + + /// Returns the coordinate in the output gradient tensor dy that is correspoinding to + // activation ndhw and filter position k, t, r, s + CUTLASS_HOST_DEVICE + TensorCoord at_(int n, int d, int h, int w, int t, int r, int s) const { + + if (problem_size_.mode == Mode::kConvolution) { + t = problem_size_.T - 1 - t; + r = problem_size_.R - 1 - r; + s = problem_size_.S - 1 - s; + } + + int z = d + problem_size_.pad_d - t * problem_size_.dilation_d; + int p = h + problem_size_.pad_h - r * problem_size_.dilation_h; + int q = w + problem_size_.pad_w - s * problem_size_.dilation_w; + + return TensorCoord(n, z, p, q, filter_k_); + } + + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_byte_offset_(LongIndex byte_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + pointer_[s] += byte_offset; + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask_(bool clear) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + // We are using inline PTX assembly here to avoid an CUDA C++ compilation + // artifact in which control flow instructions are generated. Instead, our + // intent is to predicate the mov instructions. 
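+ //
+ // Functionally, each of the three statements below is the branchless update
+ //   masks_[s][i] = clear ? 0 : masks_[s][i];
+ // expressed as a predicated mov so the compiler cannot reintroduce divergent
+ // control flow; the #else path is the plain C++ fallback for host compilation.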
+ #if defined(__CUDA_ARCH__) + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(masks_[s][0]) + : + "r"((int)clear), + "r"(masks_[s][0]) + ); + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(masks_[s][1]) + : + "r"((int)clear), + "r"(masks_[s][1]) + ); + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(masks_[s][2]) + : + "r"((int)clear), + "r"(masks_[s][2]) + ); + #else + if (clear) { + masks_[s][0] = 0; + masks_[s][1] = 0; + masks_[s][2] = 0; + } + #endif + } + } + +public: + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + add_byte_offset_(pointer_offset * sizeof_bits::value / 8); + } + + + CUTLASS_HOST_DEVICE + void advance() { + + int next_idx = 0; + + // moves to the next tile + ++filter_s_; + if (filter_s_ == problem_size_.S) { + + filter_s_ = 0; + ++filter_r_; + next_idx = 1; + + if (filter_r_ == problem_size_.R) { + filter_r_ = 0; + ++filter_t_; + + if (filter_t_ < problem_size_.T) { + next_idx = 2; + } + else { + filter_t_ = 0; + next_idx = 3; + } + } + } + + add_byte_offset_(params_.inc_next[next_idx]); + + if (next_idx == 3) { + filter_k_ += params_.filter_k_delta; + } + + clear_mask_(filter_k_ >= problem_size_.K); + } + + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask() { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + masks_[s][0] = Mask(0); + masks_[s][1] = Mask(0); + masks_[s][2] = Mask(0); + } + } + + CUTLASS_HOST_DEVICE + bool valid() { + + return + (masks_[iteration_strided_][0] & (Index(1) << filter_t_)) && + (masks_[iteration_strided_][1] & (Index(1) << filter_r_)) && + (masks_[iteration_strided_][2] & (Index(1) << filter_s_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast(pointer_[iteration_strided_]); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dDgradOutputGradientTileAccessIteratorOptimized &operator++() { + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
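The masks built in the constructor hold one predicate bit per filter tap and per dimension, which is why valid() above reduces to three bit tests and why can_implement() below rejects filters with T, R, or S greater than 32 (all taps must fit in one mask word). A tiny illustration with hypothetical extents:

#include <cassert>
#include <cstdint>

int main() {
  const int T = 3, R = 3, S = 3;                     // hypothetical filter extents (all <= 32)
  uint32_t mask_t = 0, mask_r = 0, mask_s = 0;

  for (int t = 0; t < T; ++t) mask_t |= (1u << t);   // pretend every depth tap is in bounds
  for (int r = 0; r < R; ++r) mask_r |= (1u << r);   // ...and every height tap
  mask_s = 0b011;                                    // ...but width tap s = 2 is out of bounds

  auto valid = [&](int t, int r, int s) {
    return (mask_t & (1u << t)) && (mask_r & (1u << r)) && (mask_s & (1u << s));
  };

  assert(valid(1, 2, 0));
  assert(!valid(1, 2, 2));
  return 0;
}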
+ CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + + // This is specialized for unit stride + if (problem_size.stride() != Coord3D({1, 1, 1})) { + return Status::kErrorNotSupported; + } + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % (128/sizeof_bits::value)) { + return Status::kErrorNotSupported; + } + + // Limit on filter size + if (problem_size.T > 32 || problem_size.R > 32 || problem_size.S > 32) { + return Status::kErrorNotSupported; + } + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h new file mode 100644 index 00000000..a858eae2 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h @@ -0,0 +1,289 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dFpropActivationTileAccessIteratorAnalytic { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv3dAnalyticParams; + +private: + + Params const ¶ms_; + ConvProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + int filter_t_; + int filter_r_; + int filter_s_; + int filter_c_; + + int offset_n_[ThreadMap::Iterations::kStrided]; + int offset_z_[ThreadMap::Iterations::kStrided]; + int offset_p_[ThreadMap::Iterations::kStrided]; + int offset_q_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dFpropActivationTileAccessIteratorAnalytic( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_t_(0), + filter_r_(0), + filter_s_(0), + filter_c_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + int offset_nzpq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + offset_n_[s] = offset_nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); + int residual = offset_nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); + + offset_z_[s] = residual / (problem_size_.P * problem_size_.Q); + residual = residual % (problem_size_.P * problem_size_.Q); + + offset_p_[s] = residual / problem_size_.Q; + offset_q_[s] = residual % problem_size_.Q; + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv3dProblemSize const &problem_size, 
Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + ++filter_t_; + if (filter_t_ < problem_size_.T) { + return; + } + filter_t_ = 0; + + filter_c_ += Shape::kColumn * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + int n = offset_n_[iteration_strided_]; + int z = offset_z_[iteration_strided_]; + int p = offset_p_[iteration_strided_]; + int q = offset_q_[iteration_strided_]; + + int t = filter_t_; + int r = filter_r_; + int s = filter_s_; + + if (problem_size_.mode == Mode::kConvolution) { + t = (problem_size_.T - 1 - filter_t_); + r = (problem_size_.R - 1 - filter_r_); + s = (problem_size_.S - 1 - filter_s_); + } + + int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; + int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; + int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; + + return TensorCoord(n, d, h, w, filter_c_); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.d() >= 0 && coord.d() < problem_size_.D && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W && + coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dFpropActivationTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
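A standalone sketch of the index arithmetic performed by the constructor and at() above: the flattened GEMM-A row offset nzpq is peeled into (n, z, p, q), and an output position plus a filter tap is mapped back to an input coordinate (shown for the depth dimension, with the filter not mirrored). All sizes are hypothetical:

#include <cassert>

int main() {
  const int Z = 4, P = 8, Q = 8;                        // hypothetical output extents
  const int stride_d = 1, pad_d = 1, dilation_d = 1;    // hypothetical conv parameters

  const int offset_nzpq = 2 * (Z * P * Q) + 3 * (P * Q) + 5 * Q + 7;

  int n = offset_nzpq / (Z * P * Q);
  int residual = offset_nzpq % (Z * P * Q);
  int z = residual / (P * Q);
  residual = residual % (P * Q);
  int p = residual / Q;
  int q = residual % Q;
  assert(n == 2 && z == 3 && p == 5 && q == 7);

  // at(): input depth touched by output depth z and filter tap t.
  int t = 2;
  int d = z * stride_d - pad_d + t * dilation_d;        // 3 - 1 + 2 = 4
  assert(d == 4);
  return 0;
}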
+ CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % (128/sizeof_bits::value)) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h new file mode 100644 index 00000000..9acc27dc --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h @@ -0,0 +1,478 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_ +> +class Conv3dFpropActivationTileAccessIteratorOptimized { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + using Mask = uint64_t; + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv3dFpropActivationIteratorOptimizedParams; + +private: + + Conv3dFpropActivationIteratorOptimizedParams const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + + // One pointer per access + char const *pointer_[ThreadMap::Iterations::kStrided]; + + // current filter position (t, r, s) + int filter_t_; + int filter_r_; + int filter_s_; + int filter_c_; + + // mask for t, r, and s + Index masks_[ThreadMap::Iterations::kStrided][3]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dFpropActivationTileAccessIteratorOptimized( + Conv3dFpropActivationIteratorOptimizedParams const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ) : + params_(params), + problem_size_(problem_size), + filter_t_(0), + filter_r_(0), + filter_s_(0), + filter_c_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); + + int offset_n[ThreadMap::Iterations::kStrided]; + int offset_z[ThreadMap::Iterations::kStrided]; + int offset_p[ThreadMap::Iterations::kStrided]; + int offset_q[ThreadMap::Iterations::kStrided]; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + pointer_[s] = reinterpret_cast(ptr); + + int offset_nzpq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // offset_n[s] = offset_nzpq / (problem_size_.Z * problem_size_.P * 
problem_size_.Q); + // int residual = offset_nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); + // + // offset_z[s] = residual / (problem_size_.P * problem_size_.Q); + // residual = residual % (problem_size_.P * problem_size_.Q); + // + // offset_p[s] = residual / problem_size_.Q; + // offset_q[s] = residual % problem_size_.Q; + // + + int residual; + + // input: (nzpq offset) output: (n offset and resudial (zpq offset)) + params.zpq_divmod(offset_n[s], residual, offset_nzpq); + // input: (zpq offset) output: (z offset and resudial (pq)) + params.pq_divmod(offset_z[s], residual, residual); + // input: (pq offset) output: (p offset and resudial (q offset)) + params.q_divmod(offset_p[s], offset_q[s], residual); + + TensorCoord coord = at_(offset_n[s], offset_z[s], offset_p[s], offset_q[s], 0, 0, 0); + + pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; + } + + clear_mask(); + + // mask predicates for filter position T + CUTLASS_PRAGMA_NO_UNROLL + for (int t = 0; t < problem_size_.T; ++t) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int t_ = t; + if (problem_size_.mode == Mode::kConvolution) { + t_ = problem_size_.T - 1 - t; + } + + int d = offset_z[s_idx] * problem_size_.stride_d - problem_size_.pad_d + t_ * problem_size_.dilation_d; + + bool pred = (offset_n[s_idx] < problem_size_.N && d >= 0 && d < problem_size_.D); + masks_[s_idx][0] |= (pred << t); + } + } + + // mask predicates for filter position R + CUTLASS_PRAGMA_NO_UNROLL + for (int r = 0; r < problem_size_.R; ++r) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int r_ = r; + if (problem_size_.mode == Mode::kConvolution) { + r_ = problem_size_.R - 1 - r; + } + + int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h; + + bool pred = (h >= 0 && h < problem_size_.H); + masks_[s_idx][1] |= (pred << r); + } + } + + // mask predicates for filter position S + CUTLASS_PRAGMA_NO_UNROLL + for (int s = 0; s < problem_size_.S; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int s_ = s; + if (problem_size_.mode == Mode::kConvolution) { + s_ = problem_size_.S - 1 - s; + } + + int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w; + + bool pred = (w >= 0 && w < problem_size_.W); + masks_[s_idx][2] |= (pred << s); + } + } + + if (filter_c_ >= problem_size.C) { + clear_mask(); + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); + } + +private: + + /// Returns the coordinate in the activations tensor X that is correspoinding to + // output nzpq and filter position t, r, s + CUTLASS_HOST_DEVICE + TensorCoord at_(int n, int z, int p, int q, int t, int r, int s) const { + + if (problem_size_.mode == Mode::kConvolution) { + t = problem_size_.T - 1 - t; + r = problem_size_.R - 1 - r; + s = problem_size_.S - 1 - s; + } + + int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; + int h = p * problem_size_.stride_h - problem_size_.pad_h + r * 
problem_size_.dilation_h; + int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; + + return TensorCoord(n, d, h, w, filter_c_); + } + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_byte_offset_(LongIndex byte_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + pointer_[s] += byte_offset; + } + } + + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask_(bool clear) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + // We are using inline PTX assembly here to avoid an CUDA C++ compilation + // artifact in which control flow instructions are generated. Instead, our + // intent is to predicate the mov instructions. + #if defined(__CUDA_ARCH__) + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(masks_[s][0]) + : + "r"((int)clear), + "r"(masks_[s][0]) + ); + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(masks_[s][1]) + : + "r"((int)clear), + "r"(masks_[s][1]) + ); + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(masks_[s][2]) + : + "r"((int)clear), + "r"(masks_[s][2]) + ); + #else + if (clear) { + masks_[s][0] = 0; + masks_[s][1] = 0; + masks_[s][2] = 0; + } + #endif + } + } + +public: + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + add_byte_offset_(pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_HOST_DEVICE + void advance() { + + int next_idx = 0; + + // moves to the next tile + ++filter_s_; + if (filter_s_ == problem_size_.S) { + + filter_s_ = 0; + ++filter_r_; + next_idx = 1; + + if (filter_r_ == problem_size_.R) { + filter_r_ = 0; + ++filter_t_; + + if (filter_t_ < problem_size_.T) { + next_idx = 2; + } + else { + filter_t_ = 0; + next_idx = 3; + } + } + } + + add_byte_offset_(params_.inc_next[next_idx]); + + if (next_idx == 3) { + filter_c_ += params_.filter_c_delta; + } + + clear_mask_(filter_c_ >= problem_size_.C); + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask() { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + masks_[s][0] = Mask(0); + masks_[s][1] = Mask(0); + masks_[s][2] = Mask(0); + } + } + + CUTLASS_HOST_DEVICE + bool valid() { + + return + (masks_[iteration_strided_][0] & (Index(1) << filter_t_)) && + (masks_[iteration_strided_][1] & (Index(1) << filter_r_)) && + (masks_[iteration_strided_][2] & (Index(1) << filter_s_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast(pointer_[iteration_strided_]); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dFpropActivationTileAccessIteratorOptimized &operator++() { + + ++iteration_contiguous_; + if (iteration_contiguous_ < 
ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % (128/sizeof_bits::value)) { + return Status::kErrorInvalidProblem; + } + + // Conv3dFpropActivationTileAccessIteratorOptimized has constraint on filter positions + // due to the number of mask bits. + if (problem_size.T > 32 || problem_size.R > 32 || problem_size.S > 32) { + return Status::kErrorNotSupported; + } + return Status::kSuccess; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h new file mode 100644 index 00000000..eced0c8b --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h @@ -0,0 +1,251 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dFpropFilterTileAccessIteratorAnalytic { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv3dAnalyticParams; + +private: + + Params const ¶ms_; + ConvProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + int filter_t_; + int filter_r_; + int filter_s_; + int filter_c_; + + int offset_k_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dFpropFilterTileAccessIteratorAnalytic( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_t_(0), + filter_r_(0), + filter_s_(0), + filter_c_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex 
pointer_offset) { + pointer_ += pointer_offset * 8 / sizeof_bits::value; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + + ++filter_t_; + if (filter_t_ < problem_size_.T) { + return; + } + filter_t_ = 0; + + filter_c_ += Shape::kRow * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the filter tensor W that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int k = offset_k_[iteration_strided_]; + + return TensorCoord(k, filter_t_, filter_r_, filter_s_, filter_c_); + } + + /// Returns true if the current coordinate is within the activations tensor W + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && + coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dFpropFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % (128/sizeof_bits::value)) { + return Status::kErrorInvalidProblem; + } + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h new file mode 100644 index 00000000..7fc36621 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h @@ -0,0 +1,277 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" + +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_ +> +class Conv3dFpropFilterTileAccessIteratorOptimized{ +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + struct Params : Conv3dFpropFilterIteratorOptimizedParams { + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Conv3dFpropFilterIteratorOptimizedParams const &base): + Conv3dFpropFilterIteratorOptimizedParams(base) { } + + CUTLASS_HOST_DEVICE + Params( + Conv3dProblemSize const &problem_size, + Layout 
const &layout + ): + Conv3dFpropFilterIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ) { + + } + }; + +private: + + Conv3dFpropFilterIteratorOptimizedParams const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + uint32_t predicates_; + int filter_trs_; + int filter_c_; + + // + // Assertions + // + + // We map predicates into bits packed in this uint32_t container + static_assert(ThreadMap::Iterations::kStrided < sizeof(predicates_) * 8, + "Currently, the number of loads per iteration is limited by the size of the predicates container."); + +public: + + CUTLASS_HOST_DEVICE + Conv3dFpropFilterTileAccessIteratorOptimized( + Conv3dFpropFilterIteratorOptimizedParams const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_{0}, + filter_trs_(0), + filter_c_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); + Index column = threadblock_offset.column() + thread_coord.strided(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); + predicates_ |= (pred << s); + } + + if (filter_c_ >= problem_size.C) { + predicates_ = 0u; + } + + pointer_ += ( + params_.layout({filter_c_, column}) + ) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + LongIndex next = params_.inc_next_trs; + + // moves to the next tile + ++filter_trs_; + if (filter_trs_ == params_.TRS) { + + filter_trs_ = 0; + next = params_.inc_next_c; + filter_c_ += params_.filter_c_delta; + } + + if (filter_c_ >= problem_size_.C) { + predicates_ = 0; + } + + pointer_ += next; + } + + /// Returns true if the current coordinate is within the filter tensor W + CUTLASS_HOST_DEVICE + bool valid() { + return (predicates_ & (1u << iteration_strided_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + return reinterpret_cast(pointer_); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dFpropFilterTileAccessIteratorOptimized &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + + // Move to the next K coordinate within the tile + pointer_ += params_.inc_next_k; + + return *this; + } + iteration_strided_ = 0; 
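+
+    // Editorial note: at this point both the contiguous and strided counters have wrapped,
+    // i.e. every access of the current threadblock tile has been visited. The caller
+    // (typically the implicit GEMM mainloop) is expected to call advance() to step to the
+    // next filter position / channel tile before iterating over accesses again.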
+ + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % (128/sizeof_bits::value)) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_params.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_params.h new file mode 100644 index 00000000..51884bc6 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_params.h @@ -0,0 +1,507 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Extracts the host-params objects into non-template code. 
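+
+    Editorial note (informal sketch, not upstream documentation): the Params objects below are
+    built once on the host and capture everything the device-side iterators would otherwise
+    recompute per access: tensor strides, fast integer-division magic numbers produced by
+    find_divisor()/FastDivmod, and byte-valued pointer increments such as inc_next[] and
+    inc_next_trs. Host-side construction mirrors the getParams() helpers of the iterators,
+    e.g. (names borrowed from the activation iterator defined earlier in this diff):
+
+        Conv3dFpropActivationIteratorOptimizedParams params(
+            problem_size, layout,
+            sizeof_bits<Element>::value,
+            {Shape::kRow, Shape::kColumn},
+            ThreadMap::kThreads, ThreadMap::kElementsPerAccess,
+            {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
+            {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided});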
+*/ + +#pragma once + +#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/threadblock/conv2d_params.h" +#include "cutlass/conv/conv3d_problem_size.h" + +#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED +#include +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Params structure used for all Conv3d analytic tile iterators +template< typename Layout_ = layout::TensorNDHWC > +struct Conv3dAnalyticParams { + + using Layout = Layout_; + + Layout layout; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv3dAnalyticParams() { } + + CUTLASS_HOST_DEVICE + Conv3dAnalyticParams( + Conv3dProblemSize const &, // unused; placeholder to match other Params interfaces. + Layout const &layout + ): layout(layout) { + + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for Conv3dFpropActivationTileIteratorOptimized +template< typename Layout_ = layout::TensorNDHWC > +struct Conv3dFpropActivationIteratorOptimizedParams; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for Conv3dFpropActivationTileIteratorOptimized +template<> +struct Conv3dFpropActivationIteratorOptimizedParams { + + using Layout = layout::TensorNDHWC; + + Layout layout; + + int64_t inc_next[4]; // {next S, next R, next T, next C} + int filter_c_delta; // number of logical elements to add to filter_c_ + int ZPQ; // product of Z*P*Q + int PQ; // product of P*Q + + FastDivmod zpq_divmod; + FastDivmod pq_divmod; + FastDivmod q_divmod; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv3dFpropActivationIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dFpropActivationIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), + PQ(problem_size.P * problem_size.Q), + ZPQ(problem_size.Z * problem_size.P * problem_size.Q), + zpq_divmod(ZPQ), + pq_divmod(PQ), + q_divmod(problem_size.Q) { + + TRACE_CONV_INITIALIZERS("conv3d_fprop", "activation", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + + int conv_sign = (problem_size.mode == Mode::kConvolution ? 
-1 : 1); + + // next S + inc_next[0] = conv_sign * ( + int64_t(layout.stride()[0]) * problem_size.dilation_w + ) * element_size_bits / 8; + + // next R + inc_next[1] = conv_sign * ( + int64_t(layout.stride()[1]) * problem_size.dilation_h + - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next T + inc_next[2] = conv_sign * ( + int64_t(layout.stride()[2]) * problem_size.dilation_d + - (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h + - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next C + inc_next[3] = ( + threadblock_shape.column() * problem_size.split_k_slices + - conv_sign * int64_t(problem_size.T - 1) * layout.stride()[2] * problem_size.dilation_d + - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h + - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // logical offset added to internal channel counter - units are elements, not bytes + filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template< typename Layout_ = layout::TensorNDHWC > +struct Conv3dFpropFilterIteratorOptimizedParams; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Conv3dFpropFilterIteratorOptimizedParams +{ + + using Layout = layout::TensorNDHWC; + + Layout layout; + int TRS; + int filter_c_delta; + + int64_t inc_next_k; // offset in units of bytes to next K position + int64_t inc_next_trs; // offset in units of bytes to next TRS position + int64_t inc_next_c; // offset in units of bytes to next C position + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv3dFpropFilterIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dFpropFilterIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout) { + + TRACE_CONV_INITIALIZERS("conv3d_fprop", "filter", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + TRS = problem_size.T * problem_size.R * problem_size.S; + + inc_next_k = (int64_t(layout.stride()[3]) * threadmap_delta.strided() * element_size_bits) / 8; + + inc_next_trs = + ( int64_t(layout.stride()[0]) + - int64_t(layout.stride()[3]) * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() + ) * element_size_bits / 8; + + inc_next_c = + ( + threadblock_shape.row() * problem_size.split_k_slices + - int64_t(TRS - 1) * layout.stride()[0] + - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3] + ) * element_size_bits / 8; + + filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters object for Conv3d DGRAD OutputGradient (dy) iterator +struct Conv3dDgradOutputGradientIteratorOptimizedParams { + + using Layout = layout::TensorNDHWC; + + Layout layout; + + int64_t inc_next[4]; // {next S, next R, next T, next K} + int filter_k_delta; // number of logical 
elements to add to filter_k_ + + FastDivmod dhw_divmod; + FastDivmod hw_divmod; + FastDivmod w_divmod; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv3dDgradOutputGradientIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dDgradOutputGradientIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), + dhw_divmod(problem_size.D * problem_size.H * problem_size.W), + hw_divmod(problem_size.H * problem_size.W), + w_divmod(problem_size.W) { + + TRACE_CONV_INITIALIZERS("conv3d_dgrad", "output_gradient", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + int conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1); + + // next S + inc_next[0] = conv_sign * ( + int64_t(layout.stride()[0]) * problem_size.dilation_w + ) * element_size_bits / 8; + + // next R + inc_next[1] = conv_sign * ( + int64_t(layout.stride()[1]) * problem_size.dilation_h + - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next T + inc_next[2] = conv_sign * ( + int64_t(layout.stride()[2]) * problem_size.dilation_d + - (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h + - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next K + inc_next[3] = ( + threadblock_shape.column() * problem_size.split_k_slices + - conv_sign * int64_t(problem_size.T - 1) * layout.stride()[2] * problem_size.dilation_d + - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h + - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // logical offset added to internal channel counter - units are elements, not bytes + filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters object for Conv2d DGRAD Filter (w) iterator +struct Conv3dDgradFilterIteratorOptimizedParams { + + using Layout = layout::TensorNDHWC; + + Layout layout; + int TRS; + int filter_k_delta; + + int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile + int64_t inc_next_trs; // offset in units of bytes to next TRS position + int64_t inc_next_k; // offset in units of bytes to next K position in subsequent tile + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv3dDgradFilterIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dDgradFilterIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), TRS(problem_size.T * problem_size.R * problem_size.S) { + + TRACE_CONV_INITIALIZERS("conv3d_dgrad", "filter", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + inc_next_strided = ((int64_t)layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8; + + inc_next_trs = + ( 
(int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] + ) * element_size_bits / 8; + + inc_next_k = + ( + threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[3] + - (problem_size.T * problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] + ) * element_size_bits / 8; + + filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; + } +}; + +/// Parameters object for Conv3d WGRAD OutputGradient iterator +struct Conv3dWgradOutputGradientIteratorOptimizedParams { + + using Layout = layout::TensorNDHWC; + using LongIndex = typename Layout::LongIndex; + + Layout layout; + + int NZPQ; // precomputd product of N*Z*P*Q for clearing predicates + int ZPQ; // product of Z*P*Q + unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ + unsigned zpq_shr; // in device code. + + int PQ; // product of P*Q + unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ + unsigned pq_shr; // in device code. + + unsigned q_mul; // precomputed quantities for fast computation of div/% by Q + unsigned q_shr; // in device code. + + LongIndex offset_next_strided; // offset in units of bytes to next nzpq coordinate within tile + LongIndex offset_next_contiguous; // offset in units of bytes to next k coordinate within tile + LongIndex inc_next_nzpq; // offset in units of bytes to next nzpq position in subsequent tile + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): layout(layout) { + + TRACE_CONV_INITIALIZERS("conv3d_wgrad", "output_gradient", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + // Incremental offsets in unites of bytes (number of elements) * element_size_bits / 8 + offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) + * element_size_bits / 8; + + offset_next_contiguous = (threadmap_delta.contiguous()) + * element_size_bits / 8; + + inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) + * element_size_bits / 8; + + // Precompute several quantities for fast modulo arithmetic. + NZPQ = problem_size.N * problem_size.Z * problem_size.P * problem_size.Q; + ZPQ = problem_size.Z * problem_size.P * problem_size.Q; + find_divisor(zpq_mul, zpq_shr, ZPQ); + + PQ = problem_size.P * problem_size.Q; + find_divisor(pq_mul, pq_shr, PQ); + + find_divisor(q_mul, q_shr, problem_size.Q); + + } +}; + +/// Parameters object for Conv3d WGRAD Activation Tile Access Iterator +struct Conv3dWgradActivationIteratorOptimizedParams { + + using Layout = layout::TensorNDHWC; + + Layout layout; + + int RSC; // product of R*S*C + unsigned rsc_mul; // precomputed quantities for fast computation of div/% by RSC + unsigned rsc_shr; // in device code. + + int SC; // product of S*C + unsigned sc_mul; // precomputed quantities for fast computation of div/% by SC + unsigned sc_shr; // in device code. 
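+
+  // Editorial note (illustrative; see cutlass/fast_math.h for the exact routine): each
+  // (mul, shr) pair is a precomputed "magic number" for its fixed divisor d, so that in
+  // device code
+  //
+  //     quotient  = (value * mul) >> shr;      // replaces value / d
+  //     remainder = value - quotient * d;      // replaces value % d
+  //
+  // avoiding hardware integer division in the inner loop.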
+ + unsigned c_mul; // precomputed quantities for fast computation of div/% by C + unsigned c_shr; // in device code. + + int ZPQ; // product of Z*P*Q + unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ + unsigned zpq_shr; // in device code. + + int PQ; // product of P*Q + unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ + unsigned pq_shr; // in device code. + + unsigned q_mul; // precomputed quantities for fast computation of div/% by Q + unsigned q_shr; // in device code. + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv3dWgradActivationIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dWgradActivationIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): layout(layout) { + + TRACE_CONV_INITIALIZERS("conv3d_wgrad", "activation", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + // Precompute several quantities for fast modulo arithmetic. + RSC = problem_size.R * problem_size.S * problem_size.C; + find_divisor(rsc_mul, rsc_shr, RSC); + + SC = problem_size.S * problem_size.C; + find_divisor(sc_mul, sc_shr, SC); + + find_divisor(c_mul, c_shr, problem_size.C); + + ZPQ = problem_size.Z * problem_size.P * problem_size.Q; + find_divisor(zpq_mul, zpq_shr, ZPQ); + + PQ = problem_size.P * problem_size.Q; + find_divisor(pq_mul, pq_shr, PQ); + + find_divisor(q_mul, q_shr, problem_size.Q); + + } +}; + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h new file mode 100644 index 00000000..cd064fb8 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h @@ -0,0 +1,287 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dWgradActivationTileAccessIteratorAnalytic { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + + static int const kAccessesPerVector = 1; + + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + struct Params { + + Layout layout; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params( + Conv3dProblemSize const &problem_size, + Layout const &layout + ): layout(layout) { + + } + }; + +private: + + Params const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + // Filter postion (t,r,s,c) in contiguous dimension stays constant for each gemm_iteration_k + int filter_t_[ThreadMap::Iterations::kContiguous]; + int filter_r_[ThreadMap::Iterations::kContiguous]; + int filter_s_[ThreadMap::Iterations::kContiguous]; + int filter_c_[ThreadMap::Iterations::kContiguous]; + + int offset_nzpq_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dWgradActivationTileAccessIteratorAnalytic( + Params const ¶ms, + Conv3dProblemSize const 
&problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + // initialize t,r,s,c filter position for every contiguous iteration + CUTLASS_PRAGMA_UNROLL + for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int trsc_offset = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + + filter_t_[c] = trsc_offset / (problem_size_.R * problem_size_.S * problem_size_.C); + int residual = trsc_offset % (problem_size_.R * problem_size_.S * problem_size_.C); + + filter_r_[c] = residual / (problem_size_.S * problem_size_.C); + residual = residual % (problem_size_.S * problem_size_.C); + + filter_s_[c] = residual / problem_size_.C; + filter_c_[c] = residual % problem_size_.C; + + } + + // initialize n, z, p, q offset for every strided iteration + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + offset_nzpq_[s] = threadblock_offset.row() + thread_coord.strided() + + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-B by a CTA-K tile + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_nzpq_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the activation tensor x that is currently pointed to + /// by the iterator. 
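+  /// (Editorial note, informal: the GEMM-K coordinate offset_nzpq_ is peeled into (n, z, p, q)
+  /// by successive division/modulo against Z*P*Q, P*Q and Q, while the GEMM-N coordinate
+  /// supplies the filter tap (t, r, s) and channel c; the activation location then follows the
+  /// usual d = z * stride_d - pad_d + t * dilation_d pattern, and likewise for h and w.)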
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int t = filter_t_[iteration_contiguous_]; + int r = filter_r_[iteration_contiguous_]; + int s = filter_s_[iteration_contiguous_]; + + if (problem_size_.mode == Mode::kConvolution) { + t = (problem_size_.T - 1 - t); + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + int n = offset_nzpq_[iteration_strided_] / (problem_size_.Z * problem_size_.P * problem_size_.Q); + int residual = offset_nzpq_[iteration_strided_] % (problem_size_.Z * problem_size_.P * problem_size_.Q); + + int z = residual / (problem_size_.P * problem_size_.Q); + residual = residual % (problem_size_.P * problem_size_.Q); + + int p = residual / problem_size_.Q; + int q = residual % problem_size_.Q; + + int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; + int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; + int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; + + return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]); + } + + /// Returns true if the current coordinate is within the activation tensor x + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.d() >= 0 && coord.d() < problem_size_.D && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W && + coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dWgradActivationTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % (128/sizeof_bits::value)) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h new file mode 100644 index 00000000..a49a6b6c --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h @@ -0,0 +1,317 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
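+
+    Editorial note (informal sketch, not part of the upstream CUTLASS documentation): for the
+    weight-gradient operator the activation tensor is operand B of the implicit GEMM, with
+
+        GEMM_N = T * R * S * C   (one column per filter tap and input channel)
+        GEMM_K = N * Z * P * Q   (the reduction runs over every output element)
+
+    which is why this iterator tracks an (nzpq, trsc) index pair instead of stepping through
+    filter positions the way the Fprop iterators do.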
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dWgradActivationTileAccessIteratorOptimized { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + struct Params : Conv3dWgradActivationIteratorOptimizedParams { + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Conv3dWgradActivationIteratorOptimizedParams const &base) + : Conv3dWgradActivationIteratorOptimizedParams(base) {} + + CUTLASS_HOST_DEVICE + Params(Conv3dProblemSize const &problem_size, Layout const &layout) + : Conv3dWgradActivationIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {} + }; + +private: + + Params const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + // Precomputed effective filter postion (t,r,s) in contiguous dimension stays constant for each gemm_iteration_k + // required for nzpq -> ndhw translation + int precomputed_filter_t_[ThreadMap::Iterations::kContiguous]; + int precomputed_filter_r_[ThreadMap::Iterations::kContiguous]; + int precomputed_filter_s_[ThreadMap::Iterations::kContiguous]; + + // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k + int filter_c_[ThreadMap::Iterations::kContiguous]; + + int offset_nzpq_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dWgradActivationTileAccessIteratorOptimized( + Params const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + // initialize t,r,s,c filter 
position for every contiguous iteration + CUTLASS_PRAGMA_UNROLL + for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int trsc_offset = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + + // The subsequent fast_divmod() operations are equivalent to the following logical computation: + // + // + // filter_t_[c] = trsc_offset / (problem_size_.R * problem_size_.S * problem_size_.C); + // int residual = trsc_offset % (problem_size_.R * problem_size_.S * problem_size_.C); + // + // filter_r_[c] = residual / (problem_size_.S * problem_size_.C); + // residual = residual % (problem_size_.S * problem_size_.C); + // + // filter_s_[c] = residual / problem_size_.C; + // filter_c_[c] = residual % problem_size_.C; + + int residual; + fast_divmod(precomputed_filter_t_[c], residual, trsc_offset, params_.RSC, params_.rsc_mul, params_.rsc_shr); + fast_divmod(precomputed_filter_r_[c], residual, residual, params_.SC, params_.sc_mul, params_.sc_shr); + fast_divmod(precomputed_filter_s_[c], filter_c_[c], residual, problem_size_.C, params_.c_mul, params_.c_shr); + + int t = precomputed_filter_t_[c]; + int r = precomputed_filter_r_[c]; + int s = precomputed_filter_s_[c]; + + if (problem_size_.mode == Mode::kConvolution) { + t = (problem_size_.T - 1 - t); + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + // effective t,r,s for every contiguous dimension + precomputed_filter_t_[c] = - problem_size_.pad_d + t * problem_size_.dilation_d; + precomputed_filter_r_[c] = - problem_size_.pad_h + r * problem_size_.dilation_h; + precomputed_filter_s_[c] = - problem_size_.pad_w + s * problem_size_.dilation_w; + + + } + + // initialize n, z, p, q offset for every strided iteration + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + offset_nzpq_[s] = threadblock_offset.row() + thread_coord.strided() + + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits<Element>::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-B by a CTA-K tile + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_nzpq_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the activation tensor x that is currently pointed to + /// by the iterator.
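For readers unfamiliar with the fast_divmod idiom, the following host-side sketch mirrors the precomputation performed in the constructor above: a linear (t,r,s,c) offset is unpacked with plain integer division, the filter is flipped for kConvolution mode, and padding and dilation are folded into an effective filter position. The at() method that follows consumes these effective positions directly. All names here are local to this illustration and are not part of the CUTLASS API.

// Illustrative only: plain '/' and '%' stand in for fast_divmod.
struct FilterPosition { int t, r, s, c; };

inline FilterPosition decompose_trsc(int trsc_offset,
                                     int T, int R, int S, int C,
                                     int pad_d, int pad_h, int pad_w,
                                     int dilation_d, int dilation_h, int dilation_w,
                                     bool cross_correlation) {
  int t = trsc_offset / (R * S * C);
  int residual = trsc_offset % (R * S * C);
  int r = residual / (S * C);
  residual = residual % (S * C);
  int s = residual / C;
  int c = residual % C;

  // kConvolution flips the filter; kCrossCorrelation does not.
  if (!cross_correlation) {
    t = T - 1 - t;
    r = R - 1 - r;
    s = S - 1 - s;
  }

  // Fold padding and dilation in once, so the per-access input coordinate later
  // reduces to z * stride_d + t_eff (and likewise for h and w).
  FilterPosition pos;
  pos.t = -pad_d + t * dilation_d;
  pos.r = -pad_h + r * dilation_h;
  pos.s = -pad_w + s * dilation_w;
  pos.c = c;
  return pos;
}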
+ + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + // The subsequent fast_divmod() operations are equivalent to the following logical computation: + // + // + // int n = offset_nzpq_[iteration_strided_] / (problem_size_.Z * problem_size_.P * problem_size_.Q); + // int residual = offset_nzpq_[iteration_strided_] % (problem_size_.Z * problem_size_.P * problem_size_.Q); + // + // int z = residual / (problem_size_.P * problem_size_.Q); + // residual = residual % (problem_size_.P * problem_size_.Q); + // + // int p = residual / problem_size_.Q; + // int q = residual % problem_size_.Q; + + int residual, n, z, p, q; + fast_divmod(n, residual, offset_nzpq_[iteration_strided_], params_.ZPQ, params_.zpq_mul, params_.zpq_shr); + fast_divmod(z, residual, residual, params_.PQ, params_.pq_mul, params_.pq_shr); + fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr); + + int d = z * problem_size_.stride_d + precomputed_filter_t_[iteration_contiguous_]; + int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_]; + int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_]; + + return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]); + } + + /// Returns true if the current coordinate is within the activation tensor x + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.d() >= 0 && coord.d() < problem_size_.D && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W && + coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dWgradActivationTileAccessIteratorOptimized &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem.
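The can_implement() check that follows reduces to a channel-divisibility requirement: the contiguous dimension must be coverable by whole 128-bit accesses. A small sketch with representative element widths illustrates the rule; the helper name is ours, not CUTLASS's.

// Illustrative only: 128-bit accesses imply C must be a multiple of the vector width.
inline bool channel_alignment_ok(int C, int element_size_bits) {
  int elements_per_128b_access = 128 / element_size_bits;  // 8 for 16-bit half, 16 for int8
  return (C % elements_per_128b_access) == 0;
}
// Example: channel_alignment_ok(64, 16) == true, channel_alignment_ok(36, 16) == false.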
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % (128/sizeof_bits::value)) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h new file mode 100644 index 00000000..89c29203 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -0,0 +1,265 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dWgradOutputGradientTileAccessIteratorAnalytic { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + struct Params { + + Layout layout; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params( + Conv3dProblemSize const &problem_size, + Layout const &layout + ): layout(layout) { + + } + }; + +private: + + Params const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + int filter_k_[ThreadMap::Iterations::kContiguous]; + + int offset_nzpq_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientTileAccessIteratorAnalytic( + Params const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)) { + + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + // initialize filter_k for every contiguous iteration + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + filter_k_[c] = threadblock_offset.row() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + } + + // initialize n, p, q offset for every strided iteration + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_nzpq_[s] = threadblock_offset.column() + thread_coord.strided() + + s * ThreadMap::Delta::kStrided; + + } + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + 
CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-A by a CTA-K tile + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_nzpq_[s] += Shape::kColumn * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the output gradient tensor Dy that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int nzpq = offset_nzpq_[iteration_strided_]; + + int n = nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); + int residual = nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); + + int z = residual / (problem_size_.P * problem_size_.Q); + residual = residual % (problem_size_.P * problem_size_.Q); + + int p = residual / problem_size_.Q; + int q = residual % problem_size_.Q; + + return TensorCoord(n, z, p, q, filter_k_[iteration_contiguous_]); + } + + + /// Returns true if the current coordinate is within the output gradient tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.d() < problem_size_.Z && + coord.h() < problem_size_.P && + coord.w() < problem_size_.Q && + coord.c() < problem_size_.K; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % (128/sizeof_bits::value)) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h new file mode 100644 index 00000000..dcd526b1 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -0,0 +1,308 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dWgradOutputGradientTileAccessIteratorOptimized { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + struct Params : Conv3dWgradOutputGradientIteratorOptimizedParams { + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Conv3dWgradOutputGradientIteratorOptimizedParams const &base) + : Conv3dWgradOutputGradientIteratorOptimizedParams(base) {} + + CUTLASS_HOST_DEVICE + Params(Conv3dProblemSize const &problem_size, Layout const &layout) + : Conv3dWgradOutputGradientIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {} + }; + +private: + + Params const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + uint32_t predicates_; + int filter_k_; + int offset_nzpq_; + +public: + + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientTileAccessIteratorOptimized( + Params const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_(0), + filter_k_(0), + offset_nzpq_(0) { + + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.row() + thread_coord.contiguous(); + offset_nzpq_ = threadblock_offset.column() + thread_coord.strided(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous; + int offset_nzpq = offset_nzpq_ + s * 
ThreadMap::Delta::kStrided; + + bool predicate = valid_(at_(offset_nzpq, filter_k)); + + uint32_t pred = (predicate ? 1u : 0); + + int pred_idx = c + s * ThreadMap::Iterations::kContiguous; + + predicates_ |= (pred << pred_idx); + } + } + + // Offset pointer to (iteration_strided_, iteration_contiguous_) = (0, 0) + pointer_ += ( + offset_nzpq_ * params.layout.stride()[0] + filter_k_ + ) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile + offset_nzpq_ += Shape::kColumn * problem_size_.split_k_slices; + + // Clear predicates if needed + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + if (offset_nzpq_ + s * ThreadMap::Delta::kStrided >= params_.NZPQ) { + uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); + predicates_ = (predicates_ & (~kClearMask)); + } + } + pointer_ += params_.inc_next_nzpq; + } + +private: + /// Returns the coordinate in the output gradient tensor Dy that is (offset_nzpq, k) pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at_(int offset_nzpq, int k) const { + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // int nzpq = offset_nzpq_; + // int n = nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); + // int residual = nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); + // + // int z = residual / (problem_size_.P * problem_size_.Q); + // residual = residual % (problem_size_.P * problem_size_.Q); + // + // int p = residual / problem_size_.Q; + // int q = residual % problem_size_.Q; + + int residual, n, z, p, q; + fast_divmod(n, residual, offset_nzpq, params_.ZPQ, params_.zpq_mul, params_.zpq_shr); + fast_divmod(z, residual, residual, params_.PQ, params_.pq_mul, params_.pq_shr); + fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr); + + return TensorCoord(n, z, p, q, k); + } + + /// Returns true if the coord is within the output gradient tensor Dy + CUTLASS_HOST_DEVICE + bool valid_(TensorCoord coord) const { + + return coord.n() < problem_size_.N && + coord.c() < problem_size_.K; + } + +public: + + /// Returns true if the current coordinate is within the output gradient tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + + LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; + return (predicates_ & (1u << pred_idx)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast( + pointer_ + + iteration_strided_ * params_.offset_next_strided + + iteration_contiguous_ * params_.offset_next_contiguous + ); + + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + 
Conv3dWgradOutputGradientTileAccessIteratorOptimized &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % (128/sizeof_bits::value)) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_direct_conv_params.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_direct_conv_params.h new file mode 100644 index 00000000..fbb178dd --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_direct_conv_params.h @@ -0,0 +1,230 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Extracts the host-params objects into non-template code. 
+*/ + +#pragma once + +#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED +#include +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized +template +struct Depthwise2dFpropDirectConvParams; + +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation +template +struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; + +/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized +template +struct Depthwise2dFpropDirectConvFilterIteratorParams; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized +template<> +struct Depthwise2dFpropDirectConvParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + int32_t activation_tile_h; + int32_t activation_tile_w; + int32_t activation_tile_hw; + FastDivmod activation_tile_w_divmod; + + int filter[2]; + int stride[2]; + int dilation[2]; + int inc_next[2]; + FastDivmod pq_divmod; + FastDivmod q_divmod; + + int activation_load_count; + int activation_storage_elements; + int activation_size; + // + // Methods + // + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvParams() { } + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + MatrixCoord threadblock_shape, ///< CTA threadblock Shape + Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock + const int element_size_bits, ///< bits of activation element + const int thread_count, ///< threads per threadblock + const int thread_count_contiguous, ///< number of threads for continuous dimension + const int element_per_load) ///< element per each load + : layout(layout) { + + filter[0] = problem_size.S; + filter[1] = problem_size.R; + + stride[0] = problem_size.stride_w; + stride[1] = problem_size.stride_h; + + dilation[0] = problem_size.dilation_w; + dilation[1] = problem_size.dilation_h; + + // Compute activation_tile size per threadblock because stride and dilation are runtime params. 
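Because stride and dilation are runtime values here, the activation-tile extent computed just below cannot be a compile-time constant. A standalone sketch of that computation (the dilated receptive field of one output tile); the function name is illustrative only:

// Illustrative only: extent of the input (activation) tile a threadblock must load to
// produce an output tile of the given extent with runtime stride and dilation.
inline int activation_tile_extent(int out_extent, int filter_extent, int stride, int dilation) {
  return (out_extent - 1) * stride + (filter_extent - 1) * dilation + 1;
}
// Example: a 4-wide output tile, 3-wide filter, stride 1, dilation 1 needs a 6-wide activation tile:
//   activation_tile_extent(4, 3, 1, 1) == 6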
+ activation_tile_h = (threadblock_output_shape.h() - 1) * problem_size.stride_h + + (problem_size.R - 1) * problem_size.dilation_h + 1; + activation_tile_w = (threadblock_output_shape.w() - 1) * problem_size.stride_w + + (problem_size.S - 1) * problem_size.dilation_w + 1; + activation_tile_hw = activation_tile_h * activation_tile_w; + + activation_tile_w_divmod = FastDivmod(activation_tile_w); + + /// Below two values could not be templatized because the stride and dilation are runtime params + activation_load_count = (thread_count_contiguous * activation_tile_hw + (thread_count - 1)) / thread_count; + activation_storage_elements = activation_load_count * element_per_load * thread_count; + activation_size = activation_storage_elements * element_size_bits / 8; + + // Fastdivmod for output P, Q + int tiles_p = + (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); + int tiles_q = (problem_size.Q + (threadblock_output_shape.w() - 1)) / + (threadblock_output_shape.w()); + + pq_divmod = FastDivmod(tiles_p * tiles_q); + q_divmod = FastDivmod(tiles_q); + + // next S + inc_next[0] = problem_size.dilation_w; + // next R + inc_next[1] = (activation_tile_w * problem_size.dilation_h - (problem_size.S - 1) * problem_size.dilation_w); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation +template <> +struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams { + using Layout = layout::TensorNHWC; + + Layout layout; + + FastDivmod pq_divmod; + FastDivmod q_divmod; + + int activation_size; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams() {} + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< Layout object + MatrixCoord threadblock_shape, ///< Threadblock Shape + Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock + const int activation_size_ ///< Activation size loaded by iterator + ) + : layout(layout), + activation_size(activation_size_) { + // Fastdivmod for output P, Q + int tiles_p = + (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); + int tiles_q = + (problem_size.Q + (threadblock_output_shape.w() - 1)) / (threadblock_output_shape.w()); + + pq_divmod = FastDivmod(tiles_p * tiles_q); + q_divmod = FastDivmod(tiles_q); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized +template <> +struct Depthwise2dFpropDirectConvFilterIteratorParams { + using Layout = layout::TensorNHWC; + + Layout layout; + + int filter_size; + + bool is_convolution; + // + // Methods + // + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvFilterIteratorParams() {} + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvFilterIteratorParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< Layout object + MatrixCoord threadblock_shape, ///< Threadblock Shape + const int filter_size_) ///< Filter size loaded by iterator + : layout(layout), + filter_size(filter_size_), + is_convolution(problem_size.mode == Mode::kConvolution){} +}; + +} // namespace threadblock +} // namespace conv +} // 
namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h new file mode 100644 index 00000000..92024181 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h @@ -0,0 +1,314 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template > +class DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation { + public: + // + // Types + // + + using Shape = Shape_; + using OutputTileShape = OutputTileShape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + // Compilation value of stride , dialtion and activation shape + using StrideShape = StrideShape_; + using DilationShape = DilationShape_; + using ActivationShape = ActivationShape_; + + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static int const kActivationSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * + sizeof_bits::value / 8; + + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); + + static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); + static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); + + // + // Parameters structure + // + + using Params = Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; + + private: + Conv2dProblemSize const &problem_size_; + Params const ¶ms_; + char const *pointer_; + + // Base channels for current threadblock + int base_c_; + // Base activation index for current threadblock + int offset_intial_npq_; + // Base activation coord for current threadblock + TensorCoord activatioin_base_; + // Intial thread positioin + int offset_initial_hwc_; + // Overall load instruction per thread. + int iterator_load_; + // thread loading position. 
+ int iterator_hwc_; + // activation N is inside the Tensor or not + bool valid_n_; + + public: + + + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = + MatrixCoord() + ) + : params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + offset_intial_npq_(threadblock_offset.row()), + offset_initial_hwc_(thread_idx), + iterator_load_(0) { + + base_c_ = threadblock_offset.column(); + + set_iteration_index(0); + + set_activation_coord(offset_intial_npq_); + + } + + CUTLASS_HOST_DEVICE + void set_activation_coord(int offset_npq) { + int offset_inital_n, offset_inital_p, offset_inital_q; + int residual; + + params_.pq_divmod(offset_inital_n, residual, offset_npq); + params_.q_divmod(offset_inital_p, offset_inital_q, residual); + + int base_n = offset_inital_n; + + int base_h = + offset_inital_p * OutputTileShape::kH * StrideShape::kRow - problem_size_.pad_h; + + int base_w = + offset_inital_q * OutputTileShape::kW * StrideShape::kColumn - problem_size_.pad_w; + + activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); + + valid_n_ = activatioin_base_.n() < problem_size_.N; + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params( + problem_size, + layout, + {Shape::kRow, Shape::kColumn}, + {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, + kActivationSize); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; + iterator_load_ = index; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Go to next threadblock + offset_intial_npq_ += problem_size_.split_k_slices; + + set_iteration_index(0); + + set_activation_coord(offset_intial_npq_); + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; + int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; + int h = next / ActivationShape::kW; + int w = next % ActivationShape::kW; + + c = c * AccessType::kElements; + + return activatioin_base_ + TensorCoord(0, h, w, c); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + bool valid_c = coord.c() < problem_size_.C; + bool valid_h = coord.h() >= 0 && coord.h() < problem_size_.H; + bool valid_w = coord.w() >= 0 && coord.w() < problem_size_.W; + return valid_n_ ? 
valid_c & valid_h & valid_w : 0; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + AccessType const *ptr = + reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation &operator++() { + + ++iterator_load_; + iterator_hwc_ += ThreadMap::kThreads; + + if (iterator_load_ < ThreadMap::Iterations::kCount) { + return *this; + } + + iterator_load_ = 0; + iterator_hwc_ = offset_initial_hwc_; + + return *this; + } + + /// Determines the activation size loaded by iterator + CUTLASS_HOST_DEVICE + int get_load_size() { + return kActivationSize; + } + + /// Determines the iterations needed + CUTLASS_HOST_DEVICE + int get_iteration_num() { + return ThreadMap::Iterations::kCount; + } + + /// Determines whether the Depthwise fprop can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check stride and dilation constraint + if (problem_size.stride_h != StrideShape::kRow || problem_size.stride_w != StrideShape::kColumn) { + return Status::kErrorInvalidProblem; + } + + if (problem_size.dilation_h != DilationShape::kRow || problem_size.dilation_w != DilationShape::kColumn) { + return Status::kErrorInvalidProblem; + } + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h new file mode 100644 index 00000000..1337a249 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h @@ -0,0 +1,291 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template > +class DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized { + public: + // + // Types + // + + using Shape = Shape_; + using OutputTileShape = OutputTileShape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); + + static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); + static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); + + // + // Parameters structure + // + + using Params = Depthwise2dFpropDirectConvParams; + + private: + Conv2dProblemSize const &problem_size_; + Params const ¶ms_; + char const *pointer_; + + // Base channels for current threadblock + int base_c_; + // Base activation index for current threadblock + int offset_intial_npq_; + 
// Base activation coord for current threadblock + TensorCoord activatioin_base_; + // Intial thread positioin + int offset_initial_hwc_; + // Overall load instruction per thread. + int iterator_load_; + // thread loading position. + int iterator_hwc_; + // Number of loads for activations tensor X. + const int number_of_loads_; + + public: + + + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = + MatrixCoord() + ) + : params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + offset_intial_npq_(threadblock_offset.row()), + offset_initial_hwc_(thread_idx), + iterator_load_(0), + number_of_loads_(params.activation_load_count) { + + base_c_ = threadblock_offset.column(); + + set_activation_coord(offset_intial_npq_); + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + void set_activation_coord(int offset_npq) { + int offset_inital_n, offset_inital_p, offset_inital_q; + int residual; + + params_.pq_divmod(offset_inital_n, residual, offset_npq); + params_.q_divmod(offset_inital_p, offset_inital_q, residual); + + int base_n = offset_inital_n; + + int base_h = + offset_inital_p * OutputTileShape::kH * problem_size_.stride_h - problem_size_.pad_h; + + int base_w = + offset_inital_q * OutputTileShape::kW * problem_size_.stride_w - problem_size_.pad_w; + + activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params( + problem_size, + layout, + {Shape::kRow, Shape::kColumn}, + {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, + sizeof_bits::value, + ThreadMap::kThreads, + ThreadMap::Detail::ShapeVec::kContiguous, + ThreadMap::kElementsPerAccess); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; + iterator_load_ = index; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Go to next threadblock + offset_intial_npq_ += problem_size_.split_k_slices; + + set_activation_coord(offset_intial_npq_); + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. 
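The at() method that follows converts a thread's flat load index into an (h, w, c) offset within the threadblock's activation tile. A minimal sketch of that mapping, with plain division standing in for activation_tile_w_divmod and all parameter names local to this example:

// Illustrative only: decompose a flat per-thread index into an offset inside the
// activation tile; the result is added to the threadblock's base coordinate.
struct TileOffset { int h, w, c; };

inline TileOffset decompose_hwc(int iterator_hwc,
                                int contiguous_threads,     // ThreadMap::Detail::ShapeVec::kContiguous
                                int activation_tile_w,
                                int elements_per_access) {
  int c  = iterator_hwc % contiguous_threads;
  int hw = iterator_hwc / contiguous_threads;
  TileOffset off;
  off.h = hw / activation_tile_w;   // activation_tile_w_divmod in the optimized path
  off.w = hw % activation_tile_w;
  off.c = c * elements_per_access;  // each access covers AccessType::kElements channels
  return off;
}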
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; + int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; + int h, w; + params_.activation_tile_w_divmod(h, w, next) ; + + c = c * AccessType::kElements; + + return activatioin_base_ + TensorCoord(0, h, w, c); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W && coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + AccessType const *ptr = + reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized &operator++() { + + ++iterator_load_; + iterator_hwc_ += ThreadMap::kThreads; + + if (iterator_load_ < number_of_loads_) { + return *this; + } + + iterator_load_ = 0; + iterator_hwc_ = offset_initial_hwc_; + + return *this; + } + + /// Determines the activation size loaded by iterator + CUTLASS_HOST_DEVICE + int get_load_size() { + return params_.activation_size; + } + + /// Determines the iterations needed + CUTLASS_HOST_DEVICE + int get_iteration_num() { + return number_of_loads_; + } + + /// Determines whether the Depthwise fprop can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h new file mode 100644 index 00000000..01955b0a --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h @@ -0,0 +1,551 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/conv/threadblock/depthwise_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Epilogue stores the data into global memory + typename Epilogue_, + /// iterator implementation variants + conv::IteratorAlgorithm IteratorAlgorithm_ = conv::IteratorAlgorithm::kOptimized, + /// Used for partial specialization + typename Enable = bool> +class DepthwiseFpropDirectConvMultipleStage : + public DepthwiseDirectConvMmaBase { +public: + ///< Base class + using Base = DepthwiseDirectConvMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Policy describing tuning details + using Policy = Policy_; + + using Epilogue = Epilogue_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + static conv::IteratorAlgorithm const kItertorAlgorithm = IteratorAlgorithm_; + + // + // Dependent types + // + + /// Fragment of accumulator tile + + using ElementC = typename Policy::Operator::ElementC; + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Internal structure exposed for introspection. 
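+  /// Holds the per-stage cp.async copy counts derived from the operand thread maps.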
+ struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DepthwiseFpropDirectConvMultipleStage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, + IteratorB &iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { + // Number of iterators is a static value. 
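+      // With kFixedStrideDilation the activation tile shape is fixed at compile time, so the
+      // copy loop can be fully unrolled over Detail::AsyncCopyIterationsPerStageA; the
+      // else-branch below bounds the same loop by iterator_A.get_iteration_num() at runtime.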
+ iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + ++this->smem_iterator_A_; + } + } else { + // Number of iterators is a runtime value. + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + ++this->smem_iterator_A_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA &iterator_A, + ///< Params of global memory iterator + typename IteratorA::Params const &iterator_a_params, + ///< iterator over B operand in global memory + IteratorB &iterator_B, + ///< Params of global memory iterator + typename IteratorB::Params const &iterator_b_params, + ///< initial value of accumulator + FragmentC const &src_accum, + /// Epilogue + Epilogue &epilogue, + ///< Output operator + typename Epilogue::OutputOp const &output_op, + ///< Tile iterator for destination + typename Epilogue::OutputTileIterator &destination_iterator, + ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + typename Epilogue::OutputTileIterator &source_iterator, + + int split_k_slices = 1 + ) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + + if (stage == 0) { + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + } + + if(kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation){ + // Number of iterators is compilation static. 
+ iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + } else { + // Number of iterators is a runtime value. + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_num(iterator_A.get_iteration_num()); + this->smem_iterator_A_.set_iteration_index(0); + + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + // Move to the next stage + iterator_A.advance(); + + this->smem_iterator_A_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + ///////////////////////////////////////////////////////////////////////////// + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); + + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + // + // Mainloop + // + + unsigned int iterations = 0; + constexpr int inner_loop_iterations = round_up(Base::kWarpGemmIterations, 2); + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { // Each iteration is a cta tile. + + accum.clear(); + + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < inner_loop_iterations; ++warp_mma_k) { + if (Base::kWarpGemmIterations % 2 == 0 || warp_mma_k + 1 != Base::kWarpGemmIterations) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
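+          // The warp fragments are double buffered: slot (warp_mma_k + 1) % 2 is filled here
+          // while slot warp_mma_k % 2 feeds the warp_mma call below.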
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k > 0) + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_A, group_start_iteration_B; + + if (warp_mma_k == 0) { + group_start_iteration_A = 0; + group_start_iteration_B = 0; + copy_tiles_and_advance( + iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + } + + if (warp_mma_k < Base::kWarpGemmIterations) { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + } + + if (warp_mma_k + 1 == inner_loop_iterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == inner_loop_iterations) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next cta + iterator_A.advance(); + + this->smem_iterator_A_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({-Base::kStages, 0}); + + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.advance(- (Base::kStages-1) * iterator_A.get_load_size()); + smem_read_stage_idx = 0; + } else { + this->warp_tile_iterator_A_.advance(iterator_A.get_load_size()); + ++smem_read_stage_idx; + } + + if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { + this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); + } + + // goback to start position. B has no multiple stage + this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Shape::kK, 0}); + + --gemm_k_iterations; + } + } + + // + // Epilogue + // + int32_t smem_base_offset = iterator_B.get_load_size() + (iterations % Base::kStages) * iterator_A.get_load_size(); + + destination_iterator.set_tile_index(iterations * split_k_slices); + + source_iterator.set_tile_index(iterations * split_k_slices); + + epilogue(output_op, destination_iterator, accum, source_iterator, smem_base_offset); + + ++iterations; + } + + // Insert fence and wait for all outstanding cp.async operations to commit. 
+ cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h new file mode 100644 index 00000000..6a698394 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h @@ -0,0 +1,261 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. 
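+    The filter tile is kept persistent in shared memory, so the iterator's advance() below
+    is a no-op once the tile has been loaded.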
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +template > +class DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized { +public: + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static int const kFilterSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * + sizeof_bits::value / 8; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + using Params = Depthwise2dFpropDirectConvFilterIteratorParams; + + protected: + + Conv2dProblemSize const &problem_size_; + Params const ¶ms_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_k_; + int offset_trs_[ThreadMap::Iterations::kStrided]; + +public: + + + + CUTLASS_HOST_DEVICE + DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_k_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_trs_[s] = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout, {Shape::kRow, Shape::kColumn}, kFilterSize); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % 
ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * 8 / sizeof_bits::value; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Do nothing because the filter is persistent in the SMEM + } + + /// Returns the coordinate in the filter tensor W that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int k = filter_k_ + iteration_vector_ * AccessType::kElements; + int trs = offset_trs_[iteration_strided_]; + + return TensorCoord(k, trs, 0 , 0); // As a 2D-matrix + } + + /// Returns true if the current coordinate is within the activations tensor W + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && + coord.h() < Shape::kColumn; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + TensorCoord coord = at(); + int64_t offset = coord.n(); + if (params_.is_convolution) { + offset += (Shape::kColumn - coord.h() - 1)* problem_size_.K; + } else { + offset += coord.h() * problem_size_.K; + } + + return reinterpret_cast(pointer_ + + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines the filter size loaded by iterator + CUTLASS_HOST_DEVICE + int get_load_size() { + return kFilterSize; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + // check whether runtime filter size is same as templated filter size. + if ((problem_size.R * problem_size.S) != Shape::kColumn) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_pipelined.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_pipelined.h new file mode 100644 index 00000000..9e41a17e --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_fprop_pipelined.h @@ -0,0 +1,336 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
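+/// Two-stage, double-buffered variant of the depthwise fprop mainloop (Base::kStages is
+/// statically asserted to be 2 below): global->shared copies for the next tile are issued
+/// while the current tile is consumed by warp-level MMAs.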
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to A operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class DepthwiseFpropPipelined : public gemm::threadblock::MmaBase { +public: + + ///< Base class + using Base = gemm::threadblock::MmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + +private: + + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DepthwiseFpropPipelined( + typename Base::SharedStorage &shared_storage, ///< Shared storage needed for 
internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC &accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const &src_accum, ///< source accumulator tile + int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel + TransformA transform_A = TransformA(), ///< transformation applied to A fragment + TransformB transform_B = TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + // Depthwise specific + int channel_start_index = 0; + int rs_plane_idx = 0; + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ + // Reset interation index. 
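+        // A full filter plane has been consumed for the current channel group, so the
+        // filter (operand B) iterator wraps back to its first position.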
+ iterator_B.set_iteration_index(0); + } + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ + // Move to next set of filter groups. + channel_start_index += Base::kWarpGemmIterations; + } + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + } + + warp_mma(accum, warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], accum); + } + + rs_plane_idx = (rs_plane_idx == gemm_k_iterations_per_channel - 1) ? 0: (rs_plane_idx + 1); + + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_mma_base.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_mma_base.h new file mode 100644 index 00000000..7e4ac5b0 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_mma_base.h @@ -0,0 +1,229 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a directconv threadblock-scoped Depthwise kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Policy object describing MmaTensorOp +template < + /// Warp-level GEMM operator (concept: gemm::warp::Mma) + typename Operator_, + /// Padding used for A operand in shared memory (concept: MatrixShape) + typename SmemPaddingA_, + /// Padding used for B operand in shared memory (concept: MatrixShape) + typename SmemPaddingB_, + /// + typename ThreadMapA_, + /// + typename ThreadMapB_, + /// Number of partitions of K dimension of GEMM + int PartitionsK = 1> +struct DepthwiseDirectConvMmaPolicy { + /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) + using Operator = Operator_; + + /// Padding used for A operand in shared memory + using SmemPaddingA = SmemPaddingA_; + + /// Padding used for B operand in shared memory + using SmemPaddingB = SmemPaddingB_; + + using ThreadMapA = ThreadMapA_; + using ThreadMapB = ThreadMapB_; + + /// Number of partitions of K dimension + static int const kPartitionsK = PartitionsK; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class DepthwiseDirectConvMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = cutlass::gemm:: + GemmShape; + + /// Number of warp-level GEMM oeprations + /// kWarpGemmIterations could be even and odd. + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape<1, // Not determined at compile-time :( + Shape::kN + Policy::SmemPaddingA::kRow>; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; // Tile N = 64? + + public: + // + // Data members + // + + // Let persistent B matrix in front of dynamic matrix A + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer for A operand + /// Not be determined at compile-time -- Just to get a Smem start address. + AlignedBuffer operand_A; + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { return TensorRefA{operand_A.data(), LayoutA()}; } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DepthwiseDirectConvMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h new file mode 100644 index 00000000..49ad555a --- /dev/null +++ 
b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h @@ -0,0 +1,952 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data + layout of the global memory fragments, data types, and internal tile sizes. + + Partial specializations for threadblock::Mma operations targeting depthwise related simt instructions. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/warp/mma_depthwise_simt.h" + +#include "cutlass/gemm/threadblock/mma_pipelined.h" +#include "cutlass/gemm/threadblock/mma_singlestage.h" + +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/conv/threadblock/depthwise_mma_base.h" + +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h" + +#include "cutlass/arch/cache_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +namespace detail { +// +// Convert a WarpShapeM which is the whole tile of elements into the number of elements (2D) held by +// each partitions within warp. +// The goal is for each thread's tile of elements to be as square as +// possible for performance (4x4 will be faster than 2x8). +template // The number of partitions within the warp +struct SimtWarpShape { + // kP * kQ * WarpNumThreadsM = WarpShapeM + // If needed, enable more specializations. 
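+  // Worked example: SimtWarpShape<16, 2> below yields kP x kQ = 2 x 4, since 2 * 4 * 2 == 16.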
+}; +template <> +struct SimtWarpShape<4, 4> { + static constexpr int kP = 1; + static constexpr int kQ = 1; +}; + +template <> +struct SimtWarpShape<4, 2> { + static constexpr int kP = 2; + static constexpr int kQ = 1; +}; + +template <> +struct SimtWarpShape<4, 1> { + static constexpr int kP = 2; + static constexpr int kQ = 2; +}; + +template <> +struct SimtWarpShape<8, 1> { + static constexpr int kP = 2; + static constexpr int kQ = 4; +}; +template <> +struct SimtWarpShape<8, 2> { + static constexpr int kP = 2; + static constexpr int kQ = 2; +}; +template <> +struct SimtWarpShape<8, 4> { + static constexpr int kP = 1; + static constexpr int kQ = 2; +}; + +template <> +struct SimtWarpShape<16, 1> { + static constexpr int kP = 4; + static constexpr int kQ = 4; +}; +template <> +struct SimtWarpShape<16, 2> { + static constexpr int kP = 2; + static constexpr int kQ = 4; +}; +template <> +struct SimtWarpShape<16, 4> { + static constexpr int kP = 2; + static constexpr int kQ = 2; +}; + +template +struct SimtWarpShape<25, WarpNumThreadsM> { + static_assert(WarpNumThreadsM == 1, "WarpShapeM could not be evenly splited by threads"); + static constexpr int kP = 5; + static constexpr int kQ = 5; +}; + +template <> +struct SimtWarpShape<32, 1> { + static constexpr int kP = 4; + static constexpr int kQ = 8; +}; + +template <> +struct SimtWarpShape<32, 2> { + static constexpr int kP = 4; + static constexpr int kQ = 4; +}; + +template <> +struct SimtWarpShape<32, 4> { + static constexpr int kP = 2; + static constexpr int kQ = 4; +}; + +} // namespace detail + +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeA_ = 0, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeB_ = 0, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DepthwiseMmaCoreWithLaneAccessSize; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of threadblock-scoped output tile + typename ThreadBlockOutputShape, + /// Shape of filter shape per threadblock + typename FilterShape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeA_ = 0, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeB_ = 0, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + /// Stride ( MatrixShape ) + typename StrideShape = cutlass::MatrixShape<-1, -1>, + /// Dilation ( MatrixShape ) + typename DilationShape = cutlass::MatrixShape<-1, -1>, + /// Activation Shape loaded by threadblock + typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DepthwiseDirectConvMmaCoreWithLaneAccessSize; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// per-element transformation for elements of A + ComplexTransform TransformA, + /// per-element transformation for elements of B + ComplexTransform TransformB, + bool IsComplex +> +struct DepthwiseMmaCoreWithLaneAccessSize< + Shape, WarpShape, InstructionShape, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + OperatorClass, -1, -1, Stages, Operator, AccumulatorsInRowMajor, + CacheOpA, CacheOpB, TransformA, TransformB, IsComplex +> : cutlass::gemm::threadblock::DefaultMmaCore< + Shape, WarpShape, InstructionShape, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + OperatorClass, Stages, Operator, AccumulatorsInRowMajor, + CacheOpA, CacheOpB, TransformA, TransformB, IsComplex +> {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Size of a warp-scoped per thread access (a value of -1 indicates the default) + int kLaneAccessSizeA_, + /// Size of a warp-scoped per thread access (a value of -1 indicates the default) + int kLaneAccessSizeB_, + /// Operation performed by GEMM + typename Operator_> +struct DepthwiseMmaCoreWithLaneAccessSize, + ElementA_, + 
layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + kLaneAccessSizeA_, + kLaneAccessSizeB_, + 2, + Operator_> : public cutlass::gemm::threadblock::DefaultMmaCore, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + 2, + Operator_> { + using Base = cutlass::gemm::threadblock::DefaultMmaCore, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + 2, + Operator_>; + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + + static int const kLaneAccessSizeA = kLaneAccessSizeA_; + static int const kLaneAccessSizeB = kLaneAccessSizeB_; + + // Divisility requirements + static_assert( kLaneAccessSizeA > 0 && kLaneAccessSizeB > 0, + "Size of a warp-scoped per thread access should be larger then ZERO" ); + + /// Default Operator + using Operator = Operator_; + + /// Number of warps present + using WarpCount = typename Base::WarpCount; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." + ); + + /// Number of threads per warp + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + static int const kElementsPerAccess = 1; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajor; + using SmemLayoutB = layout::RowMajor; + + // + // Iterators to write to shared memory are same as base class + // + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level op + static const int WarpNumThreadsM = cutlass::gemm::threadblock::detail::simt_get_warp_threads_m(); + static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; + static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; + static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; + static const int numElementsA = kLaneAccessSizeA / sizeof_bits::value; + static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; + static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); + static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); + + static int const kPaddingM = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); + static int const kPaddingN = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); + + static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), + "Padding must be divisible by Lane"); + + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) + >; + + /// Policy used to define MmaPipelined + using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy< + MmaWarpSimt, + MatrixShape, // skew for A matrix to avoid SMEM bank conflicts + MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts + WarpCount::kK + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: row-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape) + typename ThreadBlockOutputShape_, + /// Shape of filter shape per threadblock + typename FilterShape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeA_, + /// Number of stages + int Stages_, + /// Operation performed by GEMM + typename Operator_> +struct DepthwiseDirectConvMmaCoreWithLaneAccessSize, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + kLaneAccessSizeA_, + 128, + Stages_, + Operator_> { + using Shape = Shape_; + using FilterShape = FilterShape_; + using WarpShape = WarpShape_; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + + static int const kLaneAccessSizeB = 128; + + // Divisility requirements + static_assert( kLaneAccessSizeB > 0, + "Size of a warp-scoped per 
thread access should be larger then ZERO" ); + + /// Default Operator + using Operator = Operator_; + + /// Number of warps present + using WarpCount = cutlass::gemm::GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + 1 + >; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." + ); + + /// Number of threads per warp + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + // For Gmem load + static int const kElementsPerAccessA = 128 / sizeof_bits::value; + static int const kElementsPerAccessB = 128 / sizeof_bits::value; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajor; + using SmemLayoutB = layout::RowMajor; + + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, // Set kStrided = 1 because activation shape is runtime value. + kThreads, + kElementsPerAccessA + >; + + /// ThreadMap of iterator A + using SmemThreadMapA = IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv< + MatrixShape<1, Shape::kN>, // set kRow is 1 because it is a runtime value + ElementA, + SmemLayoutA, + 0, + SmemThreadMapA, // was IteratorThreadMapA + true // Dynamic iterations. + >; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccessB + >; + + /// Transpose the ThreadMap of iterator B + using SmemThreadMapB = IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv< + MatrixShape, + ElementB, + SmemLayoutB, + 0, + SmemThreadMapB, // was IteratorThreadMapB + false // static iterations. + >; + + // + // Warp-level matrix multiply operator + // + // Groups per threads + // Fp32: 2 groups + // Fp16: 2 groups + static const int GroupsPerThread = sizeof(ElementB) > 1 ? 
2 : 4; + // Define the warp-level op + static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize); + static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN; + + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + + // Get output P, Q per thread + static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP; + static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ; + + static const int LaneLayout = 1; + static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; + static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN); + + // Define the output tile computed by each thread + using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>; + + // Fetch the channel with same access size + static const int LaneM = LaneN; + + // No paddings + static int const kPaddingM = 0; + static int const kPaddingN = 0; + + static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), + "Padding must be divisible by Lane"); + + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> + FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape + ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<> + ThreadBlockOutputShape_, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) + >; + + /// Policy used to define MmaPipelined + using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy< + MmaWarpSimt, + MatrixShape, // skew for A matrix to avoid SMEM bank conflicts + MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts + IteratorThreadMapA, + IteratorThreadMapB, + WarpCount::kK + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: row-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape) + typename ThreadBlockOutputShape_, + /// Shape of filter shape per threadblock + typename FilterShape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Size of a warp-scoped per thread access + int 
kLaneAccessSizeA_, + /// Number of stages + int Stages_, + /// Operation performed by GEMM + typename Operator_, + /// Stride ( MatrixShape ) + typename StrideShape_, + /// Dilation ( MatrixShape ) + typename DilationShape_, + /// Activation Shape loaded by threadblock + typename ActivationShape_> +struct DepthwiseDirectConvMmaCoreWithLaneAccessSize, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + kLaneAccessSizeA_, + 128, + Stages_, + Operator_, + IteratorAlgorithm::kFixedStrideDilation, + StrideShape_, + DilationShape_, + ActivationShape_> { + using Shape = Shape_; + using FilterShape = FilterShape_; + using WarpShape = WarpShape_; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + using StrideShape = StrideShape_; + using DilationShape = DilationShape_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + using ActivationShape = ActivationShape_; + + static int const kLaneAccessSizeB = 128; + + // Divisility requirements + static_assert( kLaneAccessSizeB > 0, + "Size of a warp-scoped per thread access should be larger then ZERO" ); + + /// Default Operator + using Operator = Operator_; + + /// Number of warps present + using WarpCount = cutlass::gemm::GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + 1 + >; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." + ); + + /// Number of threads per warp + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + // For Gmem load + static int const kElementsPerAccessA = 128 / sizeof_bits::value; + static int const kElementsPerAccessB = 128 / sizeof_bits::value; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajor; + using SmemLayoutB = layout::RowMajor; + + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccessA + >; + + /// ThreadMap of iterator A + using SmemThreadMapA = IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv< + MatrixShape, + ElementA, + SmemLayoutA, + 0, + SmemThreadMapA, // was IteratorThreadMapA + false // static iterations. + >; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccessB + >; + + /// Transpose the ThreadMap of iterator B + using SmemThreadMapB = IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv< + MatrixShape, + ElementB, + SmemLayoutB, + 0, + SmemThreadMapB, // was IteratorThreadMapB + false // static iterations. + >; + + // + // Warp-level matrix multiply operator + // + // Groups per threads + // Fp32: 2 groups + // Fp16: 2 groups + static const int GroupsPerThread = sizeof(ElementB) > 1 ? 
2 : 4; + // Define the warp-level op + static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize); + static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN; + + static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP; + static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ; + + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + + static const int LaneLayout = 1; + static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; + static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN); + + // Define the output tile computed by each thread + using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>; + + // Fetch the channel with same access size + static const int LaneM = LaneN; + + // No paddings + static int const kPaddingM = 0; + static int const kPaddingN = 0; + + static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), + "Padding must be divisible by Lane"); + + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> + FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape + ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<> + ThreadBlockOutputShape, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) + IteratorAlgorithm::kFixedStrideDilation, /// Iterator algo type + StrideShape, /// Stride ( MatrixShape ) + DilationShape, /// Dilation ( MatrixShape ) + ActivationShape /// Activation Shape loaded by threadblock + >; + + /// Policy used to define MmaPipelined + using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy< + MmaWarpSimt, + MatrixShape, // skew for A matrix to avoid SMEM bank conflicts + MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts + IteratorThreadMapA, + IteratorThreadMapB, + WarpCount::kK + >; +}; +} // namespace threadblock +} // namespace conv +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h new file mode 100644 index 00000000..c57f0676 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h @@ -0,0 +1,802 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped fused activation's + scale+bias+relu and Implicit GEMM Convolution kernel. + + The original implicit gemm will store out-of-bound data as zeroes in the + shared memory because zeros into the tensor core, zeroes out of the tensor + cores. The result is remained the same. When fusing scale+bias+relu + into the mainloop, it is no longer true because + + 0 x scale + bias = bias + + which is no longer always 0. So, instead of storing zeroes, this fused + kernel stores the out-of-bound data as a special NaN (0x7eff), when applying + scale+bias+relu, the code is like + + if (data == 0x7eff) + data = 0; + else + data = scale+bias+relu(data, scale, bias); + + See include/cutlass/conv/warp/scale_bias_relu_transformation.h for the + elementwise computation. See include/cutlass/arch/memory_sm80.h for nan fill. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" +#include "cutlass/conv/warp/scale_bias_relu_transform.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
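+/////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// Illustrative sketch only (not part of the kernel): the fused scale+bias+relu described in the
+// file comment above can be read as the following per-element function, where 0x7eff is the
+// sentinel NaN that cp_async_nan writes for out-of-bound activations. The real implementation
+// operates on packed fragments; see cutlass/conv/warp/scale_bias_relu_transform.h.
+//
+//   half_t fused_scale_bias_relu(half_t data, half_t scale, half_t bias) {
+//     uint16_t bits = reinterpret_cast<uint16_t const &>(data);
+//     if (bits == 0x7eff) {
+//       return half_t(0);                       // out-of-bound element contributes zero
+//     }
+//     half_t y = data * scale + bias;           // affine transform
+//     return (y < half_t(0)) ? half_t(0) : y;   // relu
+//   }
+//
+/////////////////////////////////////////////////////////////////////////////////////////////////
+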
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Element type of scale and bias vectors + typename ElementScaleBias_, + /// Layout of scale and bias vectors + typename LayoutScaleBias_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// WarpIterator to load Scale or Bias vector from the shared memory + typename WarpIteratorScaleBias_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaFpropFusionBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Element type of scale and bias vectors + using ElementScaleBias = ElementScaleBias_; + + /// Layout of scale and bias vectors + using LayoutScaleBias = LayoutScaleBias_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< WarpIterator to load Scale or Bias vector from the shared memory + using WarpIteratorScaleBias = WarpIteratorScaleBias_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = cutlass::gemm::GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the scale and bias vectors + using TensorRefScaleBias = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the A scale and bias vectors in shared memory + using ShapeScaleBias = + MatrixShape<1 + Policy::SmemPaddingA::kRow, + 2 * Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer for A operand Scale and Bias + AlignedBuffer operand_A_scale_bias; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a layout object for the A scale and bias vectors + CUTLASS_DEVICE + static LayoutScaleBias LayoutScaleBias() { + return LayoutScaleBias::packed( + {ShapeScaleBias::kRow, ShapeScaleBias::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return 
TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + + /// Returns a TensorRef to the A operand Scale vector + CUTLASS_HOST_DEVICE + TensorRefScaleBias operand_A_scale_bias_ref() { + return TensorRefScaleBias{operand_A_scale_bias.data(), LayoutScaleBias()}; + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of A operand scale and bias vector + /// from shared memory + WarpIteratorScaleBias warp_tile_iterator_A_scale_bias_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaFpropFusionBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_A_scale_bias_( + shared_storage.operand_A_scale_bias_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorScaleBias_, + /// Iterates over vectors of scale and bias vector in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorScaleBias_, + /// Cache operation for scale/bias operand + cutlass::arch::CacheOperation::Kind CacheOpScaleBias, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// WarpIterator to load Scale or Bias vector from the shared memory + typename WarpIteratorScaleBias_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class ImplicitGemmFpropFusionMultistage + : public MmaFpropFusionBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< 
Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorScaleBias = IteratorScaleBias_; + ///< WarpIterator to load Scale or Bias vector from the shared memory + using WarpIteratorScaleBias = WarpIteratorScaleBias_; + ///< Policy describing tuning details + using Policy = Policy_; + ///< Base class + using Base = MmaFpropFusionBase; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScaleBias = SmemIteratorScaleBias_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + static cutlass::arch::CacheOperation::Kind const kCacheOpScaleBias = + CacheOpScaleBias; + + // + // Dependent types + // + + /// Fragment of accumulator tile + + using ElementC = typename Policy::Operator::ElementC; + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Internal structure exposed for introspection. + struct Detail { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpLoadedFragmentScaleBias = + typename WarpIteratorScaleBias::Fragment; + + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of A operand scale vector to shared memory + SmemIteratorScaleBias smem_iterator_A_scale_bias_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + ImplicitGemmFpropFusionMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_A_scale_bias_(shared_storage.operand_A_scale_bias_ref(), + thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location 
within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_scale_bias_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, + IteratorScaleBias &iterator_A_scale_bias, + IteratorB &iterator_B, int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / 8; + + // Uses nan fill for out of bound data + cutlass::arch::cp_async_nan( + dst_ptr, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + + ++this->smem_iterator_A_; + } + } + + // Async Copy for operand A scale and bias vector. Scale and bias vectors + // are small. One iteration is enough. 
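+    // The whole per-stage scale/bias vector fits in a single cp.async, so it is issued only
+    // with the first group (group_start_A == 0) instead of once per group as for operands A and B.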
+ if (group_start_A == 0) { + typename IteratorScaleBias::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_scale_bias_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorScaleBias::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_A_scale_bias.get(), iterator_A_scale_bias.valid()); + } + + iterator_B.set_iteration_index(group_start_B); + + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale and bias vectors in global memory + IteratorScaleBias iterator_A_scale_bias, + ///< initial value of accumulator + FragmentC const &src_accum, + ///< number of iterations per channel + int gemm_k_iterations_per_channel = 0, + ///< Imaginary strides used for planar-complex only - ignored here + int64_t imag_stride_A = 0, + int64_t imag_stride_B = 0) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / 8; + + // Uses Nan fill for out of bound data + cutlass::arch::cp_async_nan( + dst_ptr, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + ++this->smem_iterator_A_; + } + + // Async Copy for operand A scale and bias vectors. Scale and bias + // vectors are small. One iteration is enough. 
+ { + typename IteratorScaleBias::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_scale_bias_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorScaleBias::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_A_scale_bias.get(), iterator_A_scale_bias.valid()); + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.advance(); + iterator_A_scale_bias.advance(); + iterator_B.advance(); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_A_scale_bias_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpLoadedFragmentScaleBias warp_loaded_frag_A_scale_bias[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + cutlass::conv::warp::FpropScaleBiasReluTransform + elementwise_transform; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_scale_bias_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_A_scale_bias_.load( + warp_loaded_frag_A_scale_bias[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_bias_; + ++this->warp_tile_iterator_B_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance(iterator_A, iterator_A_scale_bias, iterator_B); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + elementwise_transform(warp_transformed_frag_A[0], + warp_loaded_frag_A_scale_bias[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
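+        // The fragment arrays are double buffered: iteration warp_mma_k consumes buffer
+        // (warp_mma_k % 2) while these loads fill buffer ((warp_mma_k + 1) % 2) for the next
+        // iteration, overlapping shared-memory loads with the warp-level MMA.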
+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_scale_bias_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_bias_.load( + warp_loaded_frag_A_scale_bias[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_bias_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) { + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_loaded_frag_A_scale_bias[warp_mma_k % 2]); + } + + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + + // Issue global->shared copies for the next stage + int group_start_iteration_A, group_start_iteration_B; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + group_start_iteration_A = 0; + group_start_iteration_B = 0; + } else { + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + } + + copy_tiles_and_advance(iterator_A, iterator_A_scale_bias, iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + elementwise_transform( + warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_A_scale_bias[(warp_mma_k + 1) % 2]); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + // Inserts a fence to group cp.async instructions into stages. 
+ cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.advance(); + iterator_A_scale_bias.advance(); + iterator_B.advance(); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_A_scale_bias_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_A_scale_bias_.add_tile_offset( + {0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_A_scale_bias_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + } + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/implicit_gemm_multistage.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/implicit_gemm_multistage.h new file mode 100644 index 00000000..7cf2ca0f --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/implicit_gemm_multistage.h @@ -0,0 +1,539 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class ImplicitGemmMultistage : + public gemm::threadblock::MmaBase { +public: + ///< Base class + using Base = gemm::threadblock::MmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + + using ElementC = typename Policy::Operator::ElementC; + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Internal structure exposed for introspection. 
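+  // Note on Detail below: kAccessesPerGroupA/B split the AsyncCopyIterationsPerStageA/B cp.async
+  // instructions evenly (rounding up) across the kWarpGemmIterations warp-level MMAs of a stage,
+  // so the global->shared copies for the next stage are issued incrementally from the mainloop.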
+ struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical + // accuracy, where each mainloop iteration first accumulates into a temporary + // set of freshly-cleared accumulators, which are subsequently added to the + // final accumulator set. + static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation::value; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + ImplicitGemmMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA &iterator_A, IteratorB &iterator_B, + int group_start_A = 0, int group_start_B = 0) { + + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + + if 
(group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const &src_accum, + ///< number of iterations per channel + int gemm_k_iterations_per_channel = 0, + ///< Imaginary strides used for planar-complex only - ignored here + int64_t imag_stride_A = 0, + int64_t imag_stride_B = 0) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.advance(); + iterator_B.advance(); + + 
this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance(iterator_A, iterator_B); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum; + + if (Detail::kStagedAccumulation) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_A, group_start_iteration_B; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + group_start_iteration_A = 0; + group_start_iteration_B = 0; + } else { + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + } + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, + group_start_iteration_B); + + if (Detail::kStagedAccumulation) { + warp_mma( + tmp_accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + tmp_accum + ); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + } + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.advance(); + iterator_B.advance(); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + } + } + + } + + if (Detail::kStagedAccumulation) { + accum = plus_accum(accum, tmp_accum); + } + + // Insert fence and wait for all outstanding cp.async operations to commit. 
+ cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/implicit_gemm_pipelined.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/implicit_gemm_pipelined.h new file mode 100644 index 00000000..f92280f2 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/implicit_gemm_pipelined.h @@ -0,0 +1,320 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
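Before the double-buffered variant below, it helps to see the staging discipline the multistage mainloop above implements: the prologue fills kStages - 1 shared-memory tiles with cp.async, and every mainloop iteration drains one staged tile while refilling another, wrapping the write and read iterators back to the start of the circular buffer once they reach kStages - 1. The host-side sketch below models only that bookkeeping; kStages, the tile counter, and the modulo wrap are illustrative stand-ins (the kernel expresses the wrap with negative add_tile_offset calls, and the 2-stage ImplicitGemmPipelined defined next reduces it to an XOR ping-pong between two buffers).

    // Host-side sketch, not CUTLASS code: models the circular staging of the
    // multistage mainloop. Each smem slot records which k-tile it currently
    // holds; the real kernel wraps the iterators with negative tile offsets
    // instead of the modulo used here.
    #include <array>
    #include <cstdio>

    constexpr int kStages = 4;   // assumed stage depth (Base::kStages)

    int main() {
      std::array<int, kStages> smem_slot{};
      int write_stage = 0;
      int read_stage = 0;
      int next_k_tile = 0;

      // Prologue: issue kStages - 1 complete stages of cp.async copies.
      for (int s = 0; s < kStages - 1; ++s) {
        smem_slot[write_stage] = next_k_tile++;
        write_stage = (write_stage + 1) % kStages;
      }

      // Mainloop: compute on one staged tile, refill one slot for a later tile.
      for (int iter = 0; iter < 8; ++iter) {
        std::printf("iter %d: warp MMA reads k-tile %d from slot %d\n",
                    iter, smem_slot[read_stage], read_stage);
        smem_slot[write_stage] = next_k_tile++;   // cp_async group + fence
        write_stage = (write_stage + 1) % kStages;
        read_stage = (read_stage + 1) % kStages;  // cp_async_wait + __syncthreads
      }
      return 0;
    }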
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to A operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class ImplicitGemmPipelined : public gemm::threadblock::MmaBase { +public: + + ///< Base class + using Base = gemm::threadblock::MmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + +private: + + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + ImplicitGemmPipelined( + typename Base::SharedStorage &shared_storage, ///< Shared storage needed for 
internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC &accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const &src_accum, ///< source accumulator tile + int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel + TransformA transform_A = TransformA(), ///< transformation applied to A fragment + TransformB transform_B = TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. 
+ + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + } + + warp_mma(accum, warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], accum); + } + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h new file mode 100644 index 00000000..3f3ab74e --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h @@ -0,0 +1,729 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped fused activation's scale+bias+relu and + Implicit GEMM Convolution kernel. + + The original implicit gemm will store out-of-bound data as zeroes in the + shared memory because zeros into the tensor core, zeroes out of the tensor + cores. The result is remained the same. When fusing scale+bias+relu + into the mainloop, it is no longer true because + + 0 x scale + bias = bias + + which is no longer always 0. So, instead of storing zeroes, this fused + kernel stores the out-of-bound data as a special NaN (0x7eff), when applying + scale+bias+relu, the code is like + + if (data == 0x7eff) + data = 0; + else + data = scale+bias+relu(data, scale, bias); + + The biggest difference compared with the fused Fprop and scale+bias+relu is + that scale and bias are loop invariant in Wgrad so that they only needs to + be loaded once before the mainloop. + + See include/cutlass/conv/warp/scale_bias_relu_transformation.h for the + elementwise computation. See include/cutlass/arch/memory_sm80.h for nan fill. + + +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" +#include "cutlass/conv/warp/scale_bias_relu_transform.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Element type of scale and bias vectors + typename ElementScaleBias_, + /// Layout of scale and bias vectors + typename LayoutScaleBias_, + /// Element type of scale and bias vectors + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaWgradFusionBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Element type of scale and bias vectors + using ElementScaleBias = ElementScaleBias_; + + /// Layout of scale and bias vectors + using LayoutScaleBias = LayoutScaleBias_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = cutlass::gemm::GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaWgradFusionBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
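The guard described in the file comment above can be read as a small scalar function. The sketch below is host-side and illustrative only: it checks the raw fp16 bit pattern against the 0x7eff sentinel written by the NaN fill and performs scale * x + bias followed by ReLU in float, whereas the actual warp transform operates on packed __half2 fragments.

    // Host-side sketch, illustrative only: the device code applies this per
    // packed __half2 element inside the wgrad scale+bias+relu transform.
    #include <cstdint>
    #include <cstdio>

    constexpr uint16_t kOobSentinel = 0x7eff;  // NaN pattern written by the nan fill

    // 'bits' is the raw fp16 pattern staged in shared memory; 'value' is the
    // same element converted to float for this sketch.
    float scale_bias_relu(uint16_t bits, float value, float scale, float bias) {
      if (bits == kOobSentinel) {
        return 0.0f;                 // out-of-bound element must not contribute bias
      }
      float y = scale * value + bias;
      return y > 0.0f ? y : 0.0f;    // ReLU
    }

    int main() {
      std::printf("%g\n", scale_bias_relu(kOobSentinel, 0.0f, 2.0f, 3.0f));  // 0
      std::printf("%g\n", scale_bias_relu(0x3c00, 1.0f, 2.0f, 3.0f));        // 5 (0x3c00 is fp16 1.0)
      return 0;
    }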
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorScaleBias_, + /// Iterates over vectors of scale and bias vector i + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class ImplicitGemmWgradFusionMultistage + : public MmaWgradFusionBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorScaleBias = IteratorScaleBias_; + ///< Policy describing tuning details + using Policy = Policy_; + ///< Base class + using Base = MmaWgradFusionBase; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + + using ElementC = typename Policy::Operator::ElementC; + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const kBBufferSize = + ((sizeof(typename Operator::ElementC) == 4) && + ((platform::is_same::value && + platform::is_same::value)) && + (Operator::Shape::kM >= 64 && Operator::Shape::kN >= 64)) + ? 
1 + : 2; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpLoadedFragmentScaleBias = typename IteratorScaleBias::Fragment; + + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + int warp_idx_m_; + + int warp_idx_n_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + ImplicitGemmWgradFusionMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM; + warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, + IteratorB &iterator_B, + int group_start_A = 0, int group_start_B = 0) { + + iterator_A.set_iteration_index(group_start_A); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B); + + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / 8; + + // Uses nan fill for out of bound data + cutlass::arch::cp_async_nan( + dst_ptr, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + ++this->smem_iterator_B_; + } + } + } + + /// Perform 
a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale and bias vectors in global memory + IteratorScaleBias iterator_B_scale_bias, + ///< initial value of accumulator + FragmentC const &src_accum, + ///< number of iterations per channel + int gemm_k_iterations_per_channel = 0, + ///< Imaginary strides used for planar-complex only - ignored here + int64_t imag_stride_A = 0, + int64_t imag_stride_B = 0) { + + // + // Prologue + // + + WarpLoadedFragmentScaleBias warp_loaded_frag_B_scale_bias; + iterator_B_scale_bias.add_tile_offset({0, warp_idx_n_}); + iterator_B_scale_bias.load(warp_loaded_frag_B_scale_bias); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / 8; + + // Uses Nan fill for out of bound data + cutlass::arch::cp_async_nan( + dst_ptr, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.advance(); + iterator_B.advance(); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Waits until kStages-2 stages have committed. 
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[Detail::kBBufferSize]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[Detail::kBBufferSize]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + cutlass::conv::warp::WgradScaleBiasReluTransform + elementwise_transform; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance(iterator_A, iterator_B); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + elementwise_transform(warp_transformed_frag_B[0], + warp_loaded_frag_B_scale_bias); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (Detail::kBBufferSize == 2) { + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize]); + ++this->warp_tile_iterator_A_; + } + + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) { + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % Detail::kBBufferSize], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % Detail::kBBufferSize], + warp_loaded_frag_B[warp_mma_k % 2]); + + elementwise_transform(warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_B_scale_bias); + } + + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % Detail::kBBufferSize], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + + if (Detail::kBBufferSize == 1) { + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + ++this->warp_tile_iterator_A_; + + } + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + elementwise_transform( + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_B_scale_bias); + } + + // Issue global->shared copies for the next stage + int group_start_iteration_A, group_start_iteration_B; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + group_start_iteration_A = 0; + group_start_iteration_B = 0; + } else { + 
group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + } + + copy_tiles_and_advance(iterator_A, iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.advance(); + iterator_B.advance(); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + } + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h new file mode 100644 index 00000000..acb36e17 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h @@ -0,0 +1,470 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates calculating the address and predicates to the load of scale and bias vectors. + + This iterator uses masks to guard out-of-bounds accesses. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedScaleBiasVectorAccessIterator +/// +template +class PredicatedScaleBiasVectorAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data. 
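The fprop specialization defined next splits the participating threads between the two vectors: the first kThreads threads take 128-bit accesses from the scale vector, the next kThreads from the bias vector, each thread owning kElementsPerAccess consecutive channels. A host-side sketch of that mapping follows; the threadblock contiguous extent and the fp16 element width are assumptions chosen for illustration.

    // Host-side sketch of the thread-to-vector mapping; the extent and element
    // width below are assumptions, not values fixed by the class.
    #include <cstdio>

    constexpr int kContiguous = 64;                      // ThreadblockShape::kContiguous (assumed)
    constexpr int kBitsPerElement = 16;                  // fp16 scale / bias
    constexpr int kElementsPerAccess = 128 / kBitsPerElement;
    constexpr int kThreads = kContiguous / kElementsPerAccess;

    int main() {
      for (int tid = 0; tid < 2 * kThreads; ++tid) {
        bool reads_scale = tid < kThreads;               // first kThreads threads -> scale vector
        int base = reads_scale ? 0 : kThreads;
        int offset = (tid - base) * kElementsPerAccess;  // contiguous channel offset
        std::printf("thread %2d -> %s[%2d .. %2d]\n", tid,
                    reads_scale ? "scale" : "bias",
                    offset, offset + kElementsPerAccess - 1);
      }
      return 0;
    }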
+/// +template +class PredicatedScaleBiasVectorAccessIterator { + public: + + using ThreadblockShape = ThreadblockShape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kElementsPerAccess = 128 / sizeof_bits::value; + static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess; + + using AccessType = AlignedArray; + + using Params = PredicatedScaleBiasVectorAccessIteratorParams; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const ¶ms_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + int problem_size_trs; + int problem_size_c; + int filter_trs_; + + TensorCoord thread_offset_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv2dProblemSize const &problem_size, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + problem_size_trs(problem_size.R * problem_size.S), + problem_size_c(problem_size.C), + filter_trs_(0) { + pointer_ = (thread_id < kThreads) + ? reinterpret_cast( + const_cast(scale_pointer)) + : reinterpret_cast( + const_cast(bias_pointer)); + + // Per-thread offset in logical coordinates of tensor + int thread_base = (thread_id < kThreads) ? 0 : kThreads; + + thread_offset_ = + threadblock_offset + + TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv3dProblemSize const &problem_size, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + problem_size_trs(problem_size.T * problem_size.R * problem_size.S), + problem_size_c(problem_size.C), + filter_trs_(0) { + pointer_ = (thread_id < kThreads) + ? reinterpret_cast( + const_cast(scale_pointer)) + : reinterpret_cast( + const_cast(bias_pointer)); + + // Per-thread offset in logical coordinates of tensor + int thread_base = (thread_id < kThreads) ? 
0 : kThreads; + + thread_offset_ = + threadblock_offset + + TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); + + set_iteration_index(0); + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv2dProblemSize const &problem_size, + /// Pointer to start of scale vector + ConstPointer scale_pointer, + /// Pointer to start of scale vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id) + : PredicatedScaleBiasVectorAccessIterator(params, problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv3dProblemSize const &problem_size, + /// Pointer to start of scale vector + ConstPointer scale_pointer, + /// Pointer to start of scale vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id) + : PredicatedScaleBiasVectorAccessIterator(params, problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + thread_offset_ = + thread_offset_ + + TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + return reinterpret_cast( + pointer_ + + (thread_offset_.contiguous() * sizeof_bits::value / 8)); + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator &operator++() { + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + ++filter_trs_; + if (filter_trs_ == problem_size_trs) { + filter_trs_ = 0; + add_tile_offset(TensorCoord(1, 0)); + } + } + + /// Increment and return an instance to self. + CUTLASS_DEVICE + PredicatedScaleBiasVectorAccessIterator operator++(int) { + PredicatedScaleBiasVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + uint32_t enabled = 0; + +#if defined(_MSC_VER) || (__CUDACC_VER_MAJOR__ < 11) + enabled = threadIdx.x < kThreads * 2; +#else + asm volatile( + "{\n" + " .reg .u32 tid_reg;\n" + " .reg .pred p;\n" + " mov.u32 tid_reg, %%tid.x;\n" + " setp.lt.u32 p, tid_reg, %1;\n" + " selp.u32 %0, 1, 0, p;\n" + "}\n" : "+r"(enabled) :"n"(kThreads * 2)); +#endif + + return ((thread_offset_.contiguous() < problem_size_c) && enabled); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major data. 
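The valid() predicate in the specialization above combines two tests before any access is issued: the thread must be one of the 2 * kThreads participants (computed with inline PTX so the bound is a compile-time immediate), and the thread's channel offset must fall inside problem_size_c. A plain host-side restatement of that predicate, with illustrative constants, is:

    // Host-side restatement of the guard; kThreads is an assumed value and
    // thread_idx stands in for threadIdx.x.
    #include <cstdio>

    constexpr int kThreads = 8;   // threads covering one vector (assumed)

    bool access_is_valid(int thread_idx, int channel_offset, int problem_size_c) {
      bool participates = thread_idx < 2 * kThreads;   // scale readers plus bias readers
      bool in_bounds = channel_offset < problem_size_c;
      return participates && in_bounds;                // predicate handed to the guarded load
    }

    int main() {
      std::printf("%d\n", access_is_valid(3, 16, 64));   // 1: participating, in range
      std::printf("%d\n", access_is_valid(20, 16, 64));  // 0: thread does not participate
      std::printf("%d\n", access_is_valid(3, 80, 64));   // 0: channel out of range
      return 0;
    }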
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedScaleBiasVectorAccessIterator { + public: + + using ThreadblockShape = ThreadblockShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear>; + + using AccessType = typename UnderlyingIterator::AccessType; + static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; + + using Params = PredicatedScaleBiasVectorAccessIteratorParams; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Extent of tensor + Conv2dProblemSize const &problem_size, + ///< Pointer to the start of the scale vector + ConstPointer scale_pointer, + ///< Pointer to the start of the bias vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params, problem_size, scale_pointer, bias_pointer, + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Extent of tensor + Conv3dProblemSize const &problem_size, + ///< Pointer to the start of the scale vector + ConstPointer scale_pointer, + ///< Pointer to the start of the bias vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params, problem_size, scale_pointer, bias_pointer, + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Conv2dProblemSize const &problem_size, ///< Extent of tensor + ConstPointer scale_pointer, ///< Pointer to the start of the scale vector + ConstPointer bias_pointer, ///< Pointer to the start of the bias vector + int thread_id ///< ID of each participating thread + ) + : PredicatedScaleBiasVectorAccessIterator(params, problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Conv3dProblemSize const &problem_size, ///< Extent of tensor + ConstPointer scale_pointer, ///< Pointer to the start of the scale vector + ConstPointer bias_pointer, ///< Pointer to the start of the bias vector + int thread_id ///< ID of each participating thread + ) + : 
PredicatedScaleBiasVectorAccessIterator(params, problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// threadblock tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator operator++(int) { + PredicatedScaleBiasVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + void advance() { + iterator_.advance(); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h new file mode 100644 index 00000000..3d155f8e --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h @@ -0,0 +1,371 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates calculating the address and predicates to the load of scale and bias vectors. + + This iterator uses masks to guard out-of-bounds accesses. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedScaleBiasVectorIterator +/// +template +class PredicatedScaleBiasVectorIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for wgrad pitch-linear data. 
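In the wgrad specialization defined below, the contiguous GEMM coordinate runs over R*S*C while scale and bias are indexed per channel, so add_tile_offset() reduces each thread's linear offset to a channel index using the precomputed sc_divmod / c_divmod objects from Params. A host-side sketch with ordinary division and modulo (assumed filter extent and channel count) shows the decomposition:

    // Host-side sketch of the channel decomposition; R, S and C are assumed,
    // and plain '%' replaces the precomputed fast divmod held in Params.
    #include <cstdio>

    int main() {
      const int R = 3, S = 3, C = 16;                // filter extent and channel count (assumed)

      for (int rsc_offset = 0; rsc_offset < R * S * C; rsc_offset += 40) {
        int residual = rsc_offset % (S * C);         // params_.sc_divmod
        int filter_c = residual % C;                 // params_.c_divmod -> index into scale / bias
        std::printf("rsc offset %3d -> channel %2d\n", rsc_offset, filter_c);
      }
      return 0;
    }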
+/// +template +class PredicatedScaleBiasVectorIterator { + public: + + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kElementsPerAccess = 1; + + using AccessType = AlignedArray; + + static int const kIterations = WarpShape::kContiguous / 8; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>; + + /// Parameters object is precomputed state and is host-constructible + using Params = Conv2dWgradActivationIteratorOptimizedParams; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const ¶ms_; + + /// Internal pointer to first access of tile + ConstPointer scale_pointer_; + ConstPointer bias_pointer_; + + /// Size of tensor + Conv2dProblemSize problem_size_; + + int32_t thread_offset_; + + // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k + int32_t filter_c_[kIterations]; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv2dProblemSize const &problem_size, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + problem_size_(problem_size), + scale_pointer_(scale_pointer), + bias_pointer_(bias_pointer) { + + thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4; + } + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv2dProblemSize const &problem_size, + /// Pointer to start of scale vector + ConstPointer scale_pointer, + /// Pointer to start of scale vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id) + : PredicatedScaleBiasVectorIterator(params, problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole warp tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + + thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous()); + + CUTLASS_PRAGMA_UNROLL + for(int c = 0; c < kIterations; ++c) { + int rsc_offset = thread_offset_ + c * 8; + + int residual, tmp; + params_.sc_divmod(tmp, residual, rsc_offset); + params_.c_divmod(tmp, filter_c_[c], residual); + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + frag.fill(__float2half2_rn(0.0f)); + __half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag); + + // load scale + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + + cutlass::arch::global_load< + __half, + sizeof(AccessType) + >( + frag_ptr[c * 
2].x, + scale_pointer_ + filter_c_[c], + true + ); + } + + // load bias + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + + cutlass::arch::global_load< + __half, + sizeof(AccessType) + >( + frag_ptr[c * 2 + 1].x, + bias_pointer_ + filter_c_[c], + true + ); + } + + // duplicate scale + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + frag_ptr[c * 2].y = frag_ptr[c * 2].x; + } + + // duplicate bias + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedScaleBiasVectorIterator { + public: + + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedScaleBiasVectorIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear>; + + using AccessType = typename UnderlyingIterator::AccessType; + static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; + using Fragment = typename UnderlyingIterator::Fragment; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedScaleBiasVectorIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Conv2dProblemSize const &problem_size, Layout const &layout) + : params_(problem_size, layout::TensorNHWC(0, 0, 0)){}; + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Extent of tensor + Conv2dProblemSize const &problem_size, + ///< Pointer to the start of the scale vector + ConstPointer scale_pointer, + ///< Pointer to the start of the bias vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, problem_size, scale_pointer, bias_pointer, + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + Params const ¶ms, ///< Precomputed parameters object + Conv2dProblemSize const &problem_size, ///< Extent of tensor + ConstPointer scale_pointer, ///< Pointer to the start of the scale vector + ConstPointer 
bias_pointer, ///< Pointer to the start of the bias vector + int thread_id ///< ID of each participating thread + ) + : PredicatedScaleBiasVectorIterator(params, problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// threadblock tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + iterator_.load(frag); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/threadblock_swizzle.h b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/threadblock_swizzle.h new file mode 100644 index 00000000..5e0798a1 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/threadblock/threadblock_swizzle.h @@ -0,0 +1,193 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements several possible threadblock-swizzling functions mapping blockIdx to + Convolution problems. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/platform/platform.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +CUTLASS_HOST_DEVICE +static int get_strided_dgrad_tile_m( +    cutlass::conv::Conv2dProblemSize const &problem_size, +    int tile_size_m) { + +  // CTAs in M dimension per starting filter position +  int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, tile_size_m); + +  // Inflate number of CTAs in M dimension to cover every starting filter position even those that +  // may fall out of valid MMA (Dy * w) but are needed to apply epilogue (beta * Dx_source) +  // and point-wise fusion +  int tile_m = tile_m_per_filter * int(problem_size.stride().product()); + +  // There is a possible performance optimization here that leads to up to a 2x speedup over the current +  // CUTLASS strided dgrad performance for stride > filter (i.e., stride={2x2} and filter={1x1}) +  // +  //  * Optimization * +  // Only launch CTAs in M dimension which contribute to a row in Dx output +  // +  // +  //  * Constraints * +  // (A) stride <= filter, for example, stride={2x2} and filter={3x3}: +  //      - (A.1): There are no constraints for this case and the optimization does +  //               not affect this case's functionality or performance. +  // (B) stride > filter, for example, stride={2x2} and filter={1x1}: +  //      - (B.1): Dx output tensor should be zero initialized +  //      - (B.2): The kernel epilogue cannot apply beta. 
Thus, beta should be zero + + return tile_m; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Threadblock swizzling function for strided dgrad convolution +struct StridedDgradHorizontalThreadblockSwizzle : + public gemm::threadblock::GemmHorizontalThreadblockSwizzle { + + using Base = gemm::threadblock::GemmHorizontalThreadblockSwizzle; + + CUTLASS_HOST_DEVICE + StridedDgradHorizontalThreadblockSwizzle() { } + + /// Returns the shape of the problem in units of logical tiles + /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) + CUTLASS_HOST_DEVICE + static gemm::GemmCoord get_tiled_shape( + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem_size, + gemm::GemmCoord tile_size, + int split_k_slices) { + + gemm::GemmCoord implicit_gemm_problem_size = + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); + + // compute number of tiles in m dimension + int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); + + // compute number of tiles in n dimension + int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); + + return gemm::GemmCoord( + tile_m, + tile_n, + split_k_slices); + } + + /// Returns the shape of the problem in units of logical tiles + /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) + private: + using Base::get_tiled_shape; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Threadblock swizzling function for strided dgrad convolution +template +struct StridedDgradIdentityThreadblockSwizzle : + public gemm::threadblock::GemmIdentityThreadblockSwizzle { + + using Base = gemm::threadblock::GemmIdentityThreadblockSwizzle; + + CUTLASS_HOST_DEVICE + StridedDgradIdentityThreadblockSwizzle() { } + + /// Returns the shape of the problem in units of logical tiles + /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) + CUTLASS_HOST_DEVICE + static gemm::GemmCoord get_tiled_shape( + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem_size, + gemm::GemmCoord tile_size, + int split_k_slices) { + + gemm::GemmCoord implicit_gemm_problem_size = + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); + + // compute number of tiles in m dimension + int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); + + // compute number of tiles in n dimension + int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); + + return gemm::GemmCoord( + tile_m, + tile_n, + split_k_slices); + } + + /// Returns the shape of the problem in units of logical tiles + /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) + private: + using Base::get_tiled_shape; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Threadblock swizzling function for GEMMs +template +struct DepthwiseDirect2dConvIdentityThreadblockSwizzle + : public gemm::threadblock::GemmIdentityThreadblockSwizzle { + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvIdentityThreadblockSwizzle() {} + + /// Returns the shape of the problem in units of logical tiles + CUTLASS_HOST_DEVICE + static gemm::GemmCoord get_tiled_shape(cutlass::conv::Operator conv_operator, + 
cutlass::conv::Conv2dProblemSize const &problem_size, + gemm::GemmCoord tile_size, + int split_k_slices) { + + gemm::GemmCoord implicit_gemm_problem_size = + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); + + return gemm::GemmCoord(1, + (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(), + split_k_slices); + } +}; + +} // namespace threadblock +} // namespace conv +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/warp/mma_depthwise_simt.h b/server/punica_kernels/include/cutlass/cutlass/conv/warp/mma_depthwise_simt.h new file mode 100644 index 00000000..ccf0ede7 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/warp/mma_depthwise_simt.h @@ -0,0 +1,380 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/thread/mma.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/thread/depthwise_mma.h" + + +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_simt_policy.h" + +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Complex transformation on operand A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transformation on operand B + ComplexTransform TransformB = ComplexTransform::kNone, + /// Used for partial specialization + typename Enable = bool> +class MmaDepthwiseSimt + : public cutlass::gemm::warp:: + MmaSimt { + using Base = cutlass::gemm::warp:: + MmaSimt; + +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassSimt; + + /// Hard-coded for now + using ArchTag = arch::Sm50; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = TransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = TransformB; + +public: + + /// Iterates over the B operand in memory + using IteratorB = cutlass::conv::warp::DepthwiseMmaSimtTileIterator< + MatrixShape, + cutlass::gemm::Operand::kB, + ElementB, + LayoutB, + Policy, + PartitionsK, + Shape::kK + >; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentB = FragmentB; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaDepthwiseSimt():Base() {} +}; + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
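+//
+// [Editorial note; not part of the original CUTLASS header] "Depthwise" means that
+// each output channel consumes only its own input channel; there is no reduction
+// across channels as there is in a dense GEMM / convolution. The warp-level update
+// realized by this class is, in scalar form (loop bounds illustrative only), roughly
+//
+//   for (int p = 0; p < output_pixels; ++p)
+//     for (int c = 0; c < channels; ++c)
+//       acc[p][c] += activation[p][c] * filter[c];   // same channel on both operands
+//
+// which is why the B-operand iterator used below only touches the filter element
+// whose channel matches the current k-group.
+//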
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Shape of filter shape per threadblock - concept: gemm::GemmShape + typename FilterShape_, + /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> + typename ThreadOutputShape_, + /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + typename ThreadBlockOutputShape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm_ = IteratorAlgorithm::kAnalytic, + /// Stride ( MatrixShape ) + typename StrideShape_ = cutlass::MatrixShape<-1, -1>, + /// Dilation ( MatrixShape ) + typename DilationShape_ = cutlass::MatrixShape<-1, -1>, + /// Activation Shape loaded by threadblock + typename ActivationShape_ = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Complex transformation on operand A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transformation on operand B + ComplexTransform TransformB = ComplexTransform::kNone, + /// Used for partial specialization + typename Enable = bool> +class MmaDepthwiseDirectConvSimt { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Shape of filter shape per threadblock - concept: gemm::GemmShape + using FilterShape = FilterShape_; + + /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> + using ThreadOutputShape = ThreadOutputShape_; + + /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Iterator algo type + static conv::IteratorAlgorithm const IteratorAlgorithm = IteratorAlgorithm_; + + /// Stride ( MatrixShape ) + using StrideShape = StrideShape_; + + /// Dilation ( MatrixShape ) + using DilationShape = DilationShape_; + + /// Activation Shape loaded by threadblock + using ActivationShape = ActivationShape_; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassSimt; + + /// Hard-coded for now + using ArchTag = arch::Sm50; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = TransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = TransformB; + + static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value || + platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) && + platform::is_same< ElementA, int8_t >::value && + 
platform::is_same< ElementB, int8_t >::value; + + using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type; + + /// Thread-level matrix multiply accumulate operator + using ThreadMma = cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct< + cutlass::gemm::GemmShape< + Shape::kM / Policy::WarpShape::kRow, // number of output pixels proccessed per thread + Shape::kN / Policy::WarpShape::kColumn, // number of channels proccessed per thread + 1>, + ElementA, + ElementB, + ElementC, + arch::OpMultiplyAdd, + dp4a_type + >; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename ThreadMma::ArchMmaOperator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Shape of the underlying instruction + using InstructionShape = cutlass::gemm::GemmShape<1,1,use_dp4a ? 4 : 1>; + +public: + + /// Iterates over the A operand in memory + using IteratorA = cutlass::conv::warp::DepthwiseDirect2dConvSimtTileIterator< + MatrixShape, // per warp + FilterShape, + ThreadOutputShape, + ThreadBlockOutputShape, + cutlass::gemm::Operand::kA, + ElementA, + Policy, + IteratorAlgorithm, + StrideShape, + DilationShape, + ActivationShape, + PartitionsK, + Shape::kK + >; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = FragmentA; + + /// Iterates over the B operand in memory + using IteratorB = cutlass::gemm::warp::MmaSimtTileIterator< + MatrixShape<1, Shape::kN>, + cutlass::gemm::Operand::kB, + ElementB, + LayoutB, + Policy, + PartitionsK, + Shape::kK + >; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentB = FragmentB; + + /// Iterates over the C operand in memory + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< + MatrixShape, + cutlass::gemm::Operand::kC, + ElementC, + LayoutC, + Policy + >; + + /// Storage for C tile + using FragmentC = typename ThreadMma::FragmentC; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaDepthwiseDirectConvSimt() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &d, + FragmentA a, + FragmentB b, + FragmentC const &c, int group_idx = 0) const { + + ThreadMma mma; + + mma(d, a, b, c); + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + dst_A = A; + dst_B = B; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace conv +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h b/server/punica_kernels/include/cutlass/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h new file mode 100644 index 00000000..a8398595 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h @@ -0,0 +1,862 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT + instructions +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/conv/convolution.h" + +#include "cutlass/arch/memory_sm75.h" + +#include "cutlass/layout/matrix.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions +/// +/// concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK = 1, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize = 1 +> +class DepthwiseMmaSimtTileIterator; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization for B operands of row-major layouts +/// +/// Concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Data type of A elements + typename Element_, + /// Shape of 
the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize> +class DepthwiseMmaSimtTileIterator + : public cutlass::gemm::warp::MmaSimtTileIterator { + + using Base = cutlass::gemm::warp::MmaSimtTileIterator; + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Operand tag + static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kB; + + /// Element type + using Element = Element_; + + /// Layout of policy + using Layout = layout::RowMajor; + + /// Decomposition of elements among threads + using Policy = Policy_; + + /// TensorRef type for loading element from a tensor + using TensorRef = typename Base::TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Thread-level shape of a fragment + using ThreadShape = typename Base::ThreadShape; + + /// Number of individual loads + using Iterations = typename Base::Iterations; + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + + static_assert(Policy::LaneMmaShape::kN == 1, "Each thread should be 1 element per LDS along the k-dim"); + +private: + + MatrixCoord lane_offset_; + int channel_idx_; + int base_channel_idx_; + int warps_n_; + + public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + DepthwiseMmaSimtTileIterator():Base() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + DepthwiseMmaSimtTileIterator( + TensorRef ref, + int lane_id + ) : Base(ref, lane_id) { + + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + warps_n_ = -1; + channel_idx_ = 0; + base_channel_idx_ = 0; + lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + DepthwiseMmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { + + if(warps_n_ == -1){ + warps_n_ = coord.column(); + } + + Base::add_tile_offset(coord); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + Array *dst_ptr = + reinterpret_cast *>(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < Iterations::kRow; ++k) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Iterations::kColumn; ++n) { + + void const *ptr = this->ref_.data() + + this->ref_.offset({-(channel_idx_ - base_channel_idx_), + n * Policy::WarpShape::kColumn}) + + pointer_offset / Policy::LaneMmaShape::kN; + + // Base_k of a warp + Base_k of current threads. + int thread_k_base_idx = + warps_n_ * Shape::kColumn / Policy::LaneMmaShape::kN + lane_offset_.column(); + + if (channel_idx_ + k == thread_k_base_idx + n * Policy::WarpShape::kColumn) { + // Depthwise kernel would only do computation when channel == k. + // Loads an element when the current computation channel == the k corresponding to this thread. 
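+            // [Editorial note] The B operand of an implicit-GEMM depthwise convolution is
+            // effectively diagonal: the filter tap for channel c contributes only to output
+            // channel c. A thread therefore issues a shared-memory load only when the
+            // k-group being processed matches the channel it owns; all other fragment slots
+            // are zero-filled in the else-branch below, skipping the corresponding SMEM traffic.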
+ arch::shared_load(dst_ptr[n + k * Iterations::kColumn], ptr); + } else { + // Reduce SMEM load + dst_ptr[n + k * Iterations::kColumn].fill(Element(0)); + } + } + } + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + if(k_group % PartitionGroupSize == 0 && k_group != 0){ + base_channel_idx_ = k_group; + } + channel_idx_ = k_group; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of filter (concept: gemm::GemmShape) + typename FilterShape_, + /// Size of the matrix to load (concept: MatrixShape) + typename ThreadOutputShape_, + /// Size of the matrix to load (concept: MatrixShape) + typename ThreadBlockOutputShape_, + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + /// Stride ( MatrixShape ) + typename StrideShape = cutlass::MatrixShape<-1, -1>, + /// Dilation ( MatrixShape ) + typename DilationShape = cutlass::MatrixShape<-1, -1>, + /// Activation Shape loaded by threadblock + typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK = 1, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize = 1> +class DepthwiseDirect2dConvSimtTileIterator; + + +/// Specialization for A operands of row-major layouts +/// +/// Concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of filter (concept: gemm::GemmShape) + typename FilterShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadOutputShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadBlockOutputShape_, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm, + /// Stride ( MatrixShape ) + typename StrideShape, + /// Dilation ( MatrixShape ) + typename DilationShape, + /// Activation Shape loaded by threadblock + typename ActivationShape, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize> +class DepthwiseDirect2dConvSimtTileIterator { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of filter (concept: gemm::GemmShape) + using FilterShape = FilterShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadOutputShape = ThreadOutputShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + /// Operand tag + static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; + + /// Element type 
+ using Element = Element_; + + /// Layout of policy + using Layout = layout::RowMajor; + + /// Decomposition of elements among threads + using Policy = Policy_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + // + // Derived quantities + // + + static_assert(!(Shape::kRow % Policy::WarpShape::kRow), + "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); + + static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); + static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); + static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); + static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); + +// Thread-level shape of a fragment + using ThreadShape = MatrixShape< + ThreadOutputShape::kNHW, // Output tile shape Computed by current threads + ThreadOutputShape::kC + >; + + static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), + "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); + + /// Number of individual loads + using Iterations = MatrixShape< + ThreadShape::kRow, + ThreadShape::kColumn / Policy::LaneMmaShape::kN + >; + + using ThreadTileCount = MatrixShape< + ThreadBlockOutputShape::kH / ThreadOutputShape::kH, + ThreadBlockOutputShape::kW / ThreadOutputShape::kW + >; + + /// Fragment object holding a thread's part of a tile + using Fragment = Array; + +protected: + + /// Internal reference + cutlass::TensorRef, layout::RowMajor> ref_; + + int activation_offset[ThreadOutputShape::kH][ThreadOutputShape::kW][Iterations::kColumn]; + int iterator_r_; + int iterator_s_; + int iterator_offset_; + + int inc_next_s_ ; + int inc_next_r_ ; + + MatrixCoord lane_offset_; +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator( + TensorRef ref, + int lane_id + ) { + + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + // Set channel offset + lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + + ref.add_coord_offset(lane_offset_); + + ref_.reset(reinterpret_cast *>(ref.data()), + ref.stride(0) / Policy::LaneMmaShape::kN); + + iterator_r_ = 0; + iterator_s_ = 0; + iterator_offset_ = 0; + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { + ref_.add_pointer_offset(offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
+ template + CUTLASS_HOST_DEVICE + void setup_initial_status(Params const& params) { + + inc_next_s_ = params.inc_next[0]; + inc_next_r_ = params.inc_next[1]; + + // Get base HW offset of current threads + int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); + int base_p_ = + (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH; + int base_q_ = + (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < ThreadOutputShape::kH; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int q = 0; q < ThreadOutputShape::kW; ++q) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < Iterations::kColumn; ++col) { + int base_w = (base_q_ + q) * params.stride[0]; + int base_h = (base_p_ + p) * params.stride[1]; + + int offset = base_h * params.activation_tile_w + base_w; + activation_offset[p][q][col] = offset; + } + } + } + } + + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { + // Set warp row and col start + lane_offset_ = MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + void advance(int32_t pointer_offset) { + ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); + iterator_s_ = 0; + iterator_r_ = 0; + iterator_offset_ = 0; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &operator++() { + ++iterator_s_; + if (iterator_s_ < FilterShape::kColumn) { + iterator_offset_ += inc_next_s_; + + return *this; + } + + iterator_s_ = 0; + + ++iterator_r_; + if (iterator_r_ < FilterShape::kRow) { + iterator_offset_ += inc_next_r_; + return *this; + } + + iterator_r_ = 0; + iterator_offset_ = 0; + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator & operator--() { + // Do nothing + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + Array *dst_ptr = + reinterpret_cast *>(&frag); + + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < ThreadOutputShape::kH; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int q = 0; q < ThreadOutputShape::kW; ++q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Iterations::kColumn; ++n) { + void const *ptr = ref_.data() + + ref_.offset({activation_offset[p][q][n] + (iterator_offset_), + n * Policy::WarpShape::kColumn}) + + pointer_offset / Policy::LaneMmaShape::kN; + arch::shared_load(dst_ptr[n + q + p * ThreadOutputShape::kW], ptr); + } + } + } + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { + // Do nothing at present. 
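+    // [Editorial note] This iterator only reads the activation tile that the threadblock
+    // staged in shared memory; nothing is ever written back through it, so the store path
+    // is left as a no-op.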
+ } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, Index pointer_offset) const { + store_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + // no operation here + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/// Specialization for A operands of row-major layouts +/// +/// Concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of filter (concept: gemm::GemmShape) + typename FilterShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadOutputShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadBlockOutputShape_, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Stride ( MatrixShape ) + typename StrideShape_, + /// Dilation ( MatrixShape ) + typename DilationShape_, + /// Activation Shape loaded by threadblock + typename ActivationShape_, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize> +class DepthwiseDirect2dConvSimtTileIterator { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of filter (concept: gemm::GemmShape) + using FilterShape = FilterShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadOutputShape = ThreadOutputShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + /// Stride ( MatrixShape ) + using StrideShape = StrideShape_; + + /// Dilation ( MatrixShape ) + using DilationShape = DilationShape_; + + /// Activation Shape loaded by threadblock + using ActivationShape = ActivationShape_; + + /// Operand tag + static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; + + /// Element type + using Element = Element_; + + /// Layout of policy + using Layout = layout::RowMajor; + + /// Decomposition of elements among threads + using Policy = Policy_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + // + // Derived quantities + // + + static_assert(!(Shape::kRow % Policy::WarpShape::kRow), + "The warp-level GEMM M size must be divisible by the number of threads arranged " + "along the M dimension."); + + static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); + static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); + static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); + static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, + "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); + + // Activations loaded by threadblock + static int const ThreadActivationShapeH = (ThreadOutputShape::kH - 1) * StrideShape::kRow + + (FilterShape::kRow - 1) * DilationShape::kRow + 1; + + static int const ThreadActivationShapeW = (ThreadOutputShape::kW - 1) * StrideShape::kColumn + + (FilterShape::kColumn - 1) * 
DilationShape::kColumn + 1; + + using ThreadActivationShape = cutlass::conv:: + TensorNHWCShape<1, ThreadActivationShapeH, ThreadActivationShapeW, ThreadOutputShape::kC>; + + // Thread-level shape of a fragment + using ThreadShape = + MatrixShape; + + static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), + "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); + + /// Number of individual loads + using Iterations = + MatrixShape; + + using ThreadTileCount = MatrixShape; + + /// Fragment object holding a thread's part of a tile + using Fragment = Array; + + protected: + /// Internal reference + cutlass::TensorRef, layout::RowMajor> ref_; + + Array + activation[ThreadActivationShape::kH][ThreadActivationShape::kW][Iterations::kColumn]; + int iterator_r_; + int iterator_s_; + + + MatrixCoord lane_offset_; + + public: + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator() {} + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator(TensorRef ref, int lane_id) { + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + // Set channel offset + lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + + ref.add_coord_offset(lane_offset_); + + ref_.reset(reinterpret_cast *>(ref.data()), + ref.stride(0) / Policy::LaneMmaShape::kN); + + iterator_r_ = 0; + iterator_s_ = 0; + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { + ref_.add_pointer_offset(offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
+ template + CUTLASS_HOST_DEVICE void setup_initial_status( + Params const ¶ms) { + + // Get base HW offset of current threads + int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); + int base_h = + (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH * StrideShape::kRow; + int base_w = + (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW * StrideShape::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < ThreadActivationShape::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < ThreadActivationShape::kW; ++w) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < Iterations::kColumn; ++col) { + int offset = (base_h + h) * ActivationShape::kW + (base_w + w); + + void const *ptr = ref_.data() + ref_.offset({offset, col * Policy::WarpShape::kColumn}); + arch::shared_load(activation[h][w][col], ptr); + } + } + } + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { + // Set warp row and col start + lane_offset_ = + MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + void advance(int32_t pointer_offset) { + ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); + iterator_s_ = 0; + iterator_r_ = 0; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &operator++() { + ++iterator_s_; + if (iterator_s_ < FilterShape::kColumn) { + return *this; + } + + iterator_s_ = 0; + + ++iterator_r_; + if (iterator_r_ < FilterShape::kRow) { + return *this; + } + + iterator_r_ = 0; + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &operator--() { + // Do nothing + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + Array *dst_ptr = + reinterpret_cast *>(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < ThreadOutputShape::kH; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int q = 0; q < ThreadOutputShape::kW; ++q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Iterations::kColumn; ++n) { + const int h = p * StrideShape::kRow + iterator_r_ * DilationShape::kRow; + const int w = q * StrideShape::kColumn + iterator_s_ * DilationShape::kColumn; + + dst_ptr[n + q + p * ThreadOutputShape::kW] = activation[h][w][n]; + } + } + } + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { + // Do nothing at present. 
+ } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, Index pointer_offset) const { + store_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + // no operation here + } +}; + +} // namespace warp +} // namespace conv +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/conv/warp/scale_bias_relu_transform.h b/server/punica_kernels/include/cutlass/cutlass/conv/warp/scale_bias_relu_transform.h new file mode 100644 index 00000000..6144cff6 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/conv/warp/scale_bias_relu_transform.h @@ -0,0 +1,221 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level per channel scale+bias+relu before + matrix multiply-accumulate operations targeting Tensor Cores. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FpropScaleBiasReluTransform { + + using T = typename FragmentActivations::Element; + + static int const NumActivations = FragmentActivations::kElements; + static int const NumScaleBias = FragmentScaleBias::kElements; + static int const MmaElements = 2; + // One element has one scale and one bias + static int const MmaScaleBiasPair = 2; + // 16816 has 2 columns + static int const MmaCols = 2; + + using MmaOperand = Array; + using ScaleBiasOperand = Array; + + CUTLASS_DEVICE + void transform(MmaOperand &activations, ScaleBiasOperand const &scale_bias) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + uint32_t *ptr_activations = reinterpret_cast(&activations); + uint32_t const *ptr_scale_bias = reinterpret_cast(&scale_bias); + + // Apply per channel scale+bias+relu if the data is not a special NaN + // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. + + // We assumes the pair of FP16 are either both inbound or both out-of-bound. + // It requires C to be an even number. 
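+    // In scalar terms (illustrative sketch only), the PTX below computes, for the packed
+    // pair of fp16 activations:
+    //   out = (act_pair == OOB_NAN_F16x2) ? 0 : relu(scale * act + bias)
+    // A single 32-bit compare guards both halves at once, which is why the pair must be
+    // either entirely in-bound or entirely out-of-bound (i.e. C must be even).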
+ asm volatile( + "{\n\t" + " .reg .pred %%p;\n\t" + " .reg .b32 t1;\n\t" + " setp.eq.u32 %%p, %2, %4;\n\t" + " fma.rn.f16x2.relu t1, %1, %2, %3;\n" + " selp.u32 %0, 0, t1, %%p;\n\t" + "}\n" + : "=r"(ptr_activations[0]) + : "r"(ptr_scale_bias[0]), "r"(ptr_activations[0]), + "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16x2)); +#else + assert(0); +#endif + } + + CUTLASS_DEVICE + void operator()(FragmentActivations &activations, + FragmentScaleBias const &scale_bias) { + MmaOperand *ptr_activations = reinterpret_cast(&activations); + ScaleBiasOperand const *ptr_scale_bias = + reinterpret_cast(&scale_bias); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < (NumActivations / MmaElements); ++i) { + transform(ptr_activations[i], ptr_scale_bias[(i / MmaScaleBiasPair) % MmaCols]); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct WgradScaleBiasReluTransform { + + using T = typename FragmentActivations::Element; + + static int const NumActivations = FragmentActivations::kElements; + static int const NumScaleBias = FragmentScaleBias::kElements; + static int const MmaElements = 2; + // One element has one scale and one bias + static int const MmaScaleBiasPair = 2; + // 16816 has 2 rows + static int const MmaRows = 2; + + using MmaOperand = Array; + using ScaleBiasOperand = Array<__half2, MmaScaleBiasPair>; + + CUTLASS_DEVICE + void transform(MmaOperand &activations, ScaleBiasOperand const &scale_bias) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + __half2 *ptr_activations = reinterpret_cast<__half2 *>(&activations); + uint32_t const *ptr_scale_bias = reinterpret_cast(&scale_bias); + +#if 1 + // CUDA + PTX version + + bool h1_oob = (reinterpret_cast(ptr_activations[0].x) == cutlass::arch::OOB_NAN_F16); + bool h2_oob = (reinterpret_cast(ptr_activations[0].y) == cutlass::arch::OOB_NAN_F16); + + // Apply per channel scale+bias+relu if the data is not a special NaN + // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. + + // We cannot gurantee that the pair of F16 are both in bound or both + // out-of-bound because C x R x S can be an odd number. + asm volatile( + "{\n\t" + " fma.rn.f16x2.relu %0, %1, %2, %3;\n" + "}" + : "=r"(reinterpret_cast(ptr_activations[0])) + : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), + "r"(ptr_scale_bias[1])); + + reinterpret_cast(ptr_activations[0]) = h1_oob ? + (reinterpret_cast(ptr_activations[0]) & 0xffff0000) : + reinterpret_cast(ptr_activations[0]); + + reinterpret_cast(ptr_activations[0]) = h2_oob ? + (reinterpret_cast(ptr_activations[0]) & 0xffff) : + reinterpret_cast(ptr_activations[0]); +#else + // pure PTX version + + // Apply per channel scale+bias+relu if the data is not a special NaN + // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. 
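+    // In scalar terms (illustrative sketch only), the PTX below handles each fp16 half
+    // independently:
+    //   out.lo = (act.lo == OOB_NAN_F16) ? 0 : relu(scale.lo * act.lo + bias.lo)
+    //   out.hi = (act.hi == OOB_NAN_F16) ? 0 : relu(scale.hi * act.hi + bias.hi)
+    // Each half carries its own predicate, so C x R x S is allowed to be odd here.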
+ asm volatile( + "{\n" + " .reg .b16 t1, t2;\n" + " .reg .b32 t3, t4, t5, t6;\n" + " .reg .pred p1, p2;\n" + " mov.b32 {t1, t2}, %2;\n" + " setp.eq.s16 p1, t1, %4;\n" + " setp.eq.s16 p2, t2, %4;\n" + " fma.rn.f16x2.relu t3, %1, %2, %3;\n" + " and.b32 t4, t3, %5;\n" + " selp.b32 t5, t4, t3, p1;\n" + " and.b32 t6, t5, %6;\n" + " selp.b32 %0, t6, t5, p2;\n" + "}\n" + : "=r"(reinterpret_cast(ptr_activations[0])) + : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), + "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16), "n"(0xffff0000), "n"(0x0000ffff)); +#endif +#else + assert(0); +#endif + } + + CUTLASS_DEVICE + void operator()(FragmentActivations &activations, + FragmentScaleBias const &scale_bias) { + MmaOperand *ptr_activations = reinterpret_cast(&activations); + ScaleBiasOperand const *ptr_scale_bias = + reinterpret_cast(&scale_bias); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < (NumActivations / MmaElements); ++i) { + transform(ptr_activations[i], ptr_scale_bias[(i / MmaRows)]); + } + } +}; +} // namespace warp +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/coord.h b/server/punica_kernels/include/cutlass/cutlass/coord.h new file mode 100644 index 00000000..d4fcffe3 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/coord.h @@ -0,0 +1,489 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief A Coord is a coordinate of arbitrary rank into a tensor or matrix +*/ + +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. 
However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/cutlass.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Statically-sized array specifying Coords within a tensor +template < + int Rank_, ///< Logical rank of coordinate + typename Index_ = int, ///< Index type used for each dimension + typename LongIndex_ = int64_t ///< Long index type used for linear offsets +> +struct Coord { + +public: + + // + // Type and constant definitions + // + + /// Number of elements in Coord + static int const kRank = Rank_; + + /// Index type used to store elements + using Index = Index_; + + /// Type used to represent linear offsets + using LongIndex = LongIndex_; + +private: + + // + // Data members + // + + /// Indices + Index idx[kRank]; + +public: + + // + // Methods + // + + /// Default ctor initializes uniformly + CUTLASS_HOST_DEVICE + explicit Coord(Index value = Index(0)) { + for (int i = 0; i < kRank; ++i) { + idx[i] = value; + } + } + + /// Constructs from an array of integers + CUTLASS_HOST_DEVICE + Coord(Index const (&_idx)[kRank]) { + for (int i = 0; i < kRank; ++i) { + idx[i] = _idx[i]; + } + } + + /// Constructs from some other Coord + template + CUTLASS_HOST_DEVICE + Coord(Coord other) { + for (int i = 0; i < kRank; ++i) { + idx[i] = other[i]; + } + } + + /// Returns a slice of the Coord which may be larger or smaller in rank + /// than this. + template + CUTLASS_HOST_DEVICE + Coord slice(int start = 0, Index identity = 0) const { + Coord result; + for (int i = 0; i < Slice; ++i) { + if (i + start < kRank) { + result[i] = idx[i + start]; + } + else { + result[i] = identity; + } + } + return result; + } + + /// Returns the index of the dimension with least value + CUTLASS_HOST_DEVICE + int min_dim_index() const { + int i = 0; + for (int j = 1; j < kRank; ++j) { + if (idx[j] < idx[i]) { + i = j; + } + } + return i; + } + + /// Returns the index of the dimension with greatest value + CUTLASS_HOST_DEVICE + int max_dim_index() const { + int i = 0; + for (int j = 1; j < kRank; ++j) { + if (idx[j] > idx[i]) { + i = j; + } + } + return i; + } + + /// Returns true if Coord is non-zero. + CUTLASS_HOST_DEVICE + explicit operator bool() const { + for (int i = 0; i < kRank; ++i) { + if (idx[i]) { + return true; + } + } + return false; + } + + /// Returns true if Coord is uniformly zero. 
+ CUTLASS_HOST_DEVICE + bool operator!() const { + for (int i = 0; i < kRank; ++i) { + if (idx[i]) { + return false; + } + } + return true; + } + + /// Element-wise addition + CUTLASS_HOST_DEVICE + Coord operator+(Coord const& b) const { + Coord c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] + b.idx[i]; + } + return c; + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + Coord operator-(Coord const& b) const { + Coord c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] - b.idx[i]; + } + return c; + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + Coord operator*(Coord const& b) const { + Coord c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] * b.idx[i]; + } + return c; + } + + /// Element-wise division + CUTLASS_HOST_DEVICE + Coord operator/(Coord const& b) const { + Coord c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] / b.idx[i]; + } + return c; + } + + /// In-place addition + CUTLASS_HOST_DEVICE + Coord& operator+=(Coord const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] += b.idx[i]; + } + return *this; + } + + /// In-place subtraction + CUTLASS_HOST_DEVICE + Coord& operator-=(Coord const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] -= b.idx[i]; + } + return *this; + } + + /// In-place multiplication + CUTLASS_HOST_DEVICE + Coord& operator*=(Coord const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] *= b.idx[i]; + } + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + Coord& operator/=(Coord const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] /= b.idx[i]; + } + return *this; + } + + /// Member access operator + CUTLASS_HOST_DEVICE Index& operator[](int dim) { return idx[dim]; } + + /// Member access operator + CUTLASS_HOST_DEVICE Index const& operator[](int dim) const { return idx[dim]; } + + /// Computes the dot product with anotherCoord object + CUTLASS_HOST_DEVICE + LongIndex dot(Coord const& b, LongIndex sum = LongIndex(0)) const { + for (int i = 0; i < kRank; ++i) { + sum += idx[i] * b.idx[i]; + } + return sum; + } + + /// Gets the index of a given Coord element + template + CUTLASS_HOST_DEVICE Index& at() { + return idx[Dim]; + } + + /// Access via index; may limit unrolling potential + CUTLASS_HOST_DEVICE + Index& at(int dim) { return idx[dim]; } + + /// Gets the index of a given Coord element + template + CUTLASS_HOST_DEVICE Index const& at() const { + return idx[Dim]; + } + + /// Access via index; may limit unrolling potential + CUTLASS_HOST_DEVICE + Index const& at(int dim) const { return idx[dim]; } + + /// Determines if two Coord<> objects are equal + CUTLASS_HOST_DEVICE + bool operator==(Coord const& b) const { + bool equal = true; + for (int i = 0; equal && i < kRank; ++i) { + equal = (idx[i] == b.idx[i]); + } + return equal; + } + + /// Not equal + CUTLASS_HOST_DEVICE + bool operator!=(Coord const& b) const { return !(*this == b); } + + /// Clamps a coordinate to a range specified by maximum and minimum values + CUTLASS_HOST_DEVICE + Coord& clamp(Coord const& max, Coord const& min = Coord()) { + for (int i = 0; i < kRank; ++i) { + idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]); + } + return *this; + } + + /// Returns the sum of all elements + CUTLASS_HOST_DEVICE + Index sum() const { + Index sum_(idx[0]); + for (int i = 1; i < kRank; ++i) { + sum_ += idx[i]; + } + return sum_; + } + + /// Returns the product of all elements + CUTLASS_HOST_DEVICE + LongIndex product() const { + LongIndex product_(idx[0]); + for (int i = 1; i < kRank; ++i) { + product_ *= 
idx[i]; + } + return product_; + } + + /// Less than operator + CUTLASS_HOST_DEVICE + bool operator<(Coord const &b) const { + for (int i = 0; i < kRank; ++i) { + if (!(idx[i] < b[i])) { + return false; + } + } + return true; + } + + /// Less than or equals operator + CUTLASS_HOST_DEVICE + bool operator<=(Coord const &b) const { + for (int i = 0; i < kRank; ++i) { + if (!(idx[i] <= b[i])) { + return false; + } + } + return true; + } + + /// Greater than operator + CUTLASS_HOST_DEVICE + bool operator>(Coord const &b) const { + return !(*this <= b); + } + + /// Greater than or equals operator + CUTLASS_HOST_DEVICE + bool operator>=(Coord const &b) const { + return !(*this < b); + } +}; + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + + +/// Scalar multiplication +template +CUTLASS_HOST_DEVICE +Coord operator*(Index s, Coord coord) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] *= s; + } + return coord; +} + +/// Scalar multiplication +template +CUTLASS_HOST_DEVICE +Coord operator*(Coord coord, Index s) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] *= s; + } + return coord; +} + +/// Scalar division +template +CUTLASS_HOST_DEVICE +Coord operator/(Index s, Coord coord) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] = s / coord[i]; + } + return coord; +} + +/// Scalar division +template +CUTLASS_HOST_DEVICE +Coord operator/(Coord coord, Index s) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] /= s; + } + return coord; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Integer-valued make_Coord +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to make a 2-element coordinate +template +CUTLASS_HOST_DEVICE +Coord<1, T> make_Coord(T _0) { + T values[1] = {_0}; + return Coord<1, T>(values); +} + +/// Helper to make a 2-element coordinate +template +CUTLASS_HOST_DEVICE +Coord<2, T> make_Coord(T _0, T _1) { + T values[2] = {_0, _1}; + return Coord<2, T>(values); +} + +/// Helper to make a 3-element coordinate +template +CUTLASS_HOST_DEVICE +Coord<3, T> make_Coord(T _0, T _1, T _2) { + T values[3] = {_0, _1, _2}; + return Coord<3, T>(values); +} + +/// Helper to make a 4-element coordinate +template +CUTLASS_HOST_DEVICE +Coord<4, T> make_Coord(T _0, T _1, T _2, T _3) { + T values[4] = {_0, _1, _2, _3}; + return Coord<4, T>(values); +} + +/// Helper to make a 5-element coordinate +template +CUTLASS_HOST_DEVICE +Coord<5, T> make_Coord(T _0, T _1, T _2, T _3, T _4) { + T values[5] = {_0, _1, _2, _3, _4}; + return Coord<5, T>(values); +} + +/// Helper to make a 1-element coordinate +template +CUTLASS_HOST_DEVICE +Coordmake_Coord_with_padding(T _0) { + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = N - 1; i > 0; --i) { + coord[i] = 0; + } + + coord[0] = _0; + + return coord; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/core_io.h b/server/punica_kernels/include/cutlass/cutlass/core_io.h new file mode 100644 index 00000000..e7c96d05 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/core_io.h @@ -0,0 +1,295 @@ +/*************************************************************************************************** 
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Helpers for printing cutlass/core objects +*/ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
+*/ +#pragma once + +#include +#include + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix.h" +#include "cutlass/quaternion.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm_enumerated_types.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Output operator for CUDA built-in dim3 type +inline std::ostream &operator<<(std::ostream &out, dim3 d) { + return out << d.x << ", " << d.y << ", " << d.z; +} + +/// Output operator for CUDA built-in error type +inline std::ostream &operator<<(std::ostream &out, cudaError_t error) { + return out << cudaGetErrorString(error); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// stream operators for cutlass namespace // +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline +std::ostream& operator<<(std::ostream& out, Array const& v) { + for (int i = 0; i < Rank; ++i) { + out << (i ? ", " : "") << v[i]; + } + return out; +} + +template +inline +std::ostream& operator<<(std::ostream& out, Coord const& coord) { + for (int i = 0; i < Rank; ++i) { + out << (i ? ", " : "") << coord[i]; + } + return out; +} + +inline +std::istream & operator>>(std::istream &stream, half_t &x) { + float tmp; + stream >> tmp; + x = static_cast(tmp); + return stream; +} + +inline +std::ostream & operator<<(std::ostream &out, half_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, bfloat16_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, tfloat32_t const &x) { + return out << float(x); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to enable formatted printing of CUTLASS scalar types to an ostream +template +struct ScalarIO { + + /// Value to print + T value; + + /// Default ctor + ScalarIO() { } + + /// Constructs from a value + ScalarIO(T value): value(value) {} +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default printing to ostream +template +inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { + return out << scalar.value; +} + +/// Printing to ostream of int8_t as integer rather than character +template <> +inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { + return out << int(scalar.value); +} + +/// Printing to ostream of uint8_t as integer rather than character +template <> +inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { + return out << unsigned(scalar.value); +} + + +/// Default printing to ostream for MatrixShape +template +inline +std::ostream & operator<<(std::ostream &out, MatrixShape const &matrix_shape) { + out << "cutlass::MatrixShape::(kRow, kColumn) {" + << cutlass::MatrixShape::kRow <<"," + << cutlass::MatrixShape::kColumn <<"}"; + return out; +} + + +/// Prints matrix to ostream +template +std::ostream & operator<<(std::ostream &out, Matrix const 
&rhs) { + + for (int i = 0; i < Rows; ++i) { + for (int j = 0; j < Columns; ++j) { + ScalarIO element(rhs.at(i, j)); + out << (j ? ", " : "") << element; + } + out << "\\n"; + } + + return out; +} + +template +std::ostream &operator<<(std::ostream &out, Quaternion const &rhs) { + + out << ScalarIO(rhs.w()) << " "; + if (rhs.x() >= 0) { + out << "+"; + } + + out << ScalarIO(rhs.x()) << "*i "; + if (rhs.y() >= 0) { + out << "+"; + } + + out << ScalarIO(rhs.y()) << "*j "; + if (rhs.z() >= 0) { + out << "+"; + } + + out << ScalarIO(rhs.z()) << "*k"; + + return out; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// stream operators for cutlass::gemm namespace // +/////////////////////////////////////////////////////////////////////////////////////////////////// +namespace gemm { + +/// Default printing to ostream for GemmShape +template +inline +std::ostream & operator<<(std::ostream &out, GemmShape const &gemm_shape) { + out << "cutlass::gemm::GemmShape::(kM, kN, kK) {" + << cutlass::gemm::GemmShape::kM <<"," + << cutlass::gemm::GemmShape::kN <<"," + << cutlass::gemm::GemmShape::kK << "}"; + return out; +} + +/// Default printing to ostream for GemmCoord +inline +std::ostream & operator<<(std::ostream &out, GemmCoord const &gemm_coord) { + out << "cutlass::gemm::GemmCoord {" + << gemm_coord.m() <<"," + << gemm_coord.n() <<"," + << gemm_coord.k() << "}"; + return out; +} + +} //namespace gemm +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// stream operators for cutlass namespace // +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default printing to ostream for PitchLinearShape +template < int Contiguous, int Strided> +inline +std::ostream & operator<<(std::ostream &out, PitchLinearShape const &pitch_linear_shape) { + out << "cutlass::PitchLinearShape:(kContiguous, kStrided) {" + << cutlass::layout::PitchLinearShape::kContiguous <<"," + << cutlass::layout::PitchLinearShape::kStrided <<"}"; + return out; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// stream operators for cutlass::conv namespace // +/////////////////////////////////////////////////////////////////////////////////////////////////// +namespace conv { +/// Default printing to ostream for Conv2dProblemSize +inline +std::ostream& operator<<(std::ostream& out, Conv2dProblemSize const& problem) { + out << "NHWC: (" << problem.N << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl + << "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C / problem.groups << ")" << std::endl + << "NPQK: (" << problem.N << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl + << "groups: (" << problem.groups << ")" << std::endl + << "Pad_h, Pad_w: (" << problem.pad_h << ", " << problem.pad_w << ")" << std::endl + << "Stride_h, Stride_w: (" << problem.stride_h << ", " << problem.stride_w << ")" << std::endl + << "Dilation_h, Dilation_w: (" << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl + << "split_k_slices: (" << problem.split_k_slices << ")" << std::endl + << "mode: (" << 
((problem.mode==conv::Mode::kConvolution) ? "conv" : "xcross") << ")"; + + return out; +} + + +/// Default printing to ostream for Conv3dProblemSize +inline +std::ostream& operator<<(std::ostream& out, Conv3dProblemSize const& problem) { + out << "NDHWC: (" << problem.N << ", " << problem.D << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl + << "KTRSC: (" << problem.K << ", " << problem.T << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl + << "NZPQK: (" << problem.N << ", " << problem.Z << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl + << "pad_d, pad_h, pad_w: (" << problem.pad_d << ", " << problem.pad_h << ", " << problem.pad_w << ")" << std::endl + << "stride_d, stride_h, stride_w: (" << problem.stride_d << ", " << problem.stride_h << ", " << problem.stride_w << ")" << std::endl + << "dilation_d, dilation_h, dilation_w: (" << problem.dilation_d << ", " << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl + << "split_k_slices: (" << problem.split_k_slices << ") " << std::endl + << "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? "conv" : "xcross") << ")"; + + return out; +} + +} // namespace conv +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/cuda_host_adapter.hpp b/server/punica_kernels/include/cutlass/cutlass/cuda_host_adapter.hpp new file mode 100644 index 00000000..644f0a0b --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/cuda_host_adapter.hpp @@ -0,0 +1,221 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief Interface betweeen a CUTLASS device-wide operator and CUDA. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/trace.h" + +#include "cutlass/platform/platform.h" +#if ! defined(__CUDACC_RTC__) +#include +#endif + +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) +# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Macro-level guard for CUDA Host Adapter +// +#if !defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) +#define CUTLASS_ENABLE_CUDA_HOST_ADAPTER false +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This class defines an object which abstracts interactions between the CUTLASS device-wide GEMM and +/// CUDA. The intention is to enable CUTLASS to be used with both the CUDA Runtime API and CUDA Driver API. +struct CudaHostAdapter { + + /// Limit the number of kernels + static constexpr int32_t kMaximumKernelCount = 4; + + /// Maximum cluster size + static constexpr int MaxClusterSize = 32; + + // + // Data members + // + + /// Handles + void *kernel_handles[kMaximumKernelCount]; + int32_t kernel_count = 0; + + // + // Methods + // + + /// Ctor + CudaHostAdapter() = default; + + /// Dtor + virtual ~CudaHostAdapter() {} + + /// Copy Ctor + inline CudaHostAdapter(const CudaHostAdapter & rhs): + kernel_count(rhs.kernel_count) + { + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + } + + /// Copy Assignment + inline CudaHostAdapter& operator=(const CudaHostAdapter & rhs) { + + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + kernel_count = rhs.kernel_count; + return *this; + } + + /// Move ctor + inline CudaHostAdapter(CudaHostAdapter && rhs): + kernel_count(rhs.kernel_count) + { + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + } + + /// Move assignment + inline CudaHostAdapter& operator=(CudaHostAdapter && rhs) { + + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + + kernel_count = rhs.kernel_count; + + return *this; + } + + /// Ctor + inline CudaHostAdapter( + void **kernel_handles_, + int32_t kernel_count_ + ): + kernel_count(kernel_count_) + { + CUTLASS_ASSERT(kernel_count >= 0); + for (int32_t i = 0; i < kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = kernel_handles_[i]; + } + } + + /// Returns true if the CudaHostAdapter is empty (kernel_count == 0) + inline bool empty() const { return !kernel_count; } + + /// Returns kernel_count + inline size_t size() const { return static_cast(kernel_count); } + + /// Queries the occupancy of a 
kernel + virtual Status query_occupancy( + int32_t *device_sms, + int32_t *sm_occupancy, + int32_t kernel_index, + int32_t thread_count, + int32_t smem_size) const = 0; + + /// Launches a kernel without using Threadblock Clusters. + virtual Status launch( + dim3 const grid_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void** kernel_params, + int32_t kernel_index) const = 0; + + /// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters. + virtual Status launch( + dim3 const grid_dims, + dim3 const cluster_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void** kernel_params, + int32_t kernel_index) const = 0; + +protected: + + /** + * Fills a buffer in Global Memory with a byte sequence copied from host memory. + * This function can be overriden to dispatch to the appropriate cuMemsetD*Async API + */ + virtual Status memsetDeviceImpl( + void* destination, ///< Device memory pointer to be filled + void const* fill_value, ///< Value to be filled in the buffer + size_t fill_size, ///< Size of the data type to be used for filling the buffer + size_t count, ///< Number of elements of size fill_size + cudaStream_t stream) const = 0; + +public: + + /// Fills a buffer in Global Memory with a byte sequence copied from host memory + template + Status memsetDevice( + void* destination, + FillValueType fill_value, + size_t count, + cudaStream_t stream) const + { + return this->memsetDeviceImpl( + destination, + &fill_value, + sizeof(FillValueType), + count, + stream); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/cutlass.h b/server/punica_kernels/include/cutlass/cutlass/cutlass.h new file mode 100644 index 00000000..a0070e84 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/cutlass.h @@ -0,0 +1,169 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Basic include for CUTLASS. +*/ + +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ + +#pragma once + +#include "cutlass/detail/helper_macros.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/// Status code returned by CUTLASS operations +enum class Status { + kSuccess, ///< Operation was successful. + kErrorMisalignedOperand, ///< operands fail alignment requirements. + kErrorInvalidDataType, ///< DataType fails requirement. + kErrorInvalidLayout, ///< Layout fails alignment requirement. + kErrorInvalidProblem, ///< Specified problem size is not supported by operator. + kErrorNotSupported, ///< Operation is not supported on current device. + kErrorWorkspaceNull, ///< The given workspace is null when it is required to be non-null. + kErrorInternal, ///< An error within CUTLASS occurred. + kErrorArchMismatch, ///< CUTLASS runs on a device that it was not compiled for. + kErrorInsufficientDriver, ///< CUTLASS runs with a driver that is too old. + kErrorMemoryAllocation, ///< Kernel launch failed due to insufficient device memory. + kInvalid ///< Status is unspecified. 
+}; + +/// Convert cutlass status to status strings +CUTLASS_HOST_DEVICE +static char const* cutlassGetStatusString(cutlass::Status status) { + switch (status) { + case cutlass::Status::kSuccess: + return "Success"; + case cutlass::Status::kErrorMisalignedOperand: + return "Error Misaligned Operand"; + case cutlass::Status::kErrorInvalidDataType: + return "Error Invalid Data Type"; + case cutlass::Status::kErrorInvalidLayout: + return "Error Invalid Layout"; + case cutlass::Status::kErrorInvalidProblem: + return "Error Invalid Problem"; + case cutlass::Status::kErrorNotSupported: + return "Error Not Supported"; + case cutlass::Status::kErrorWorkspaceNull: + return "Error Workspace Null"; + case cutlass::Status::kErrorInternal: + return "Error Internal"; + case cutlass::Status::kErrorInsufficientDriver: + return "Error Insufficient Driver"; + case cutlass::Status::kErrorArchMismatch: + return "Error Architecture Mismatch"; + case cutlass::Status::kErrorMemoryAllocation: + return "Error Memory Allocation failed"; + case cutlass::Status::kInvalid: break; + } + + return "Invalid status"; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static const int NumThreadsPerWarp = 32; +static const int NumThreadsPerWarpGroup = 128; +static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; +static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; +static const int NumThreadsPerQuad = 4; +static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper function to return true when called by thread 0 of threadblock 0. +CUTLASS_HOST_DEVICE bool thread0() { + #if defined(__CUDA_ARCH__) + return (!threadIdx.x && !threadIdx.y && !threadIdx.z) && (!blockIdx.x && !blockIdx.y && !blockIdx.z); + #else + return false; + #endif +} + +/// Returns a lane index in the warp. The threads in warp may not be convergent +CUTLASS_DEVICE +int canonical_lane_idx() { + #if defined(__CUDA_ARCH__) + return threadIdx.x % NumThreadsPerWarp; + #else + return 0; + #endif +} + +/// Returns a warp-uniform value indicating the canonical warp index of the calling threads. +/// Threads within the warp must be converged. +CUTLASS_DEVICE +int canonical_warp_idx_sync() { + #if defined(__CUDA_ARCH__) + return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarp, 0); + #else + return 0; + #endif +} + +/// Returns a warp index in the CTA. The threads in warp may not be convergent +/// As it doesn't sync the warp, it faster and allows forward progress +CUTLASS_DEVICE +int canonical_warp_idx() { + #if defined(__CUDA_ARCH__) + return threadIdx.x / NumThreadsPerWarp; + #else + return 0; + #endif +} + +/// Returns a warp-uniform value indicating the canonical warp group index of the calling threads. +/// Threads within the warp must be converged. 
+CUTLASS_DEVICE +int canonical_warp_group_idx() { + #if defined(__CUDA_ARCH__) + return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarpGroup, 0); + #else + return 0; + #endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/detail/collective.hpp b/server/punica_kernels/include/cutlass/cutlass/detail/collective.hpp new file mode 100644 index 00000000..646b5ce5 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/detail/collective.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/container/tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct deduce_mixed_width_dtype { +static_assert(I >= 0u && I <= 2u, "Valid indices are 0, 1, and 2, which represent Operand, Scale, and Bias, respectively."); + +private: + using underlying_tuple = cute::conditional_t::value, Tuple, cute::tuple>; + static constexpr size_t valid_index = cute::min(I, cute::tuple_size_v - 1); + +public: + using type = cute::conditional_t<(I < cute::tuple_size_v), + cute::tuple_element_t, + void>; +}; + +template +using deduce_mixed_width_dtype_t = typename deduce_mixed_width_dtype::type; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective diff --git a/server/punica_kernels/include/cutlass/cutlass/detail/dependent_false.hpp b/server/punica_kernels/include/cutlass/cutlass/detail/dependent_false.hpp new file mode 100644 index 00000000..76e52d2b --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/detail/dependent_false.hpp @@ -0,0 +1,86 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +/// @brief A bool constant that depends on one or more template parameters. 
+/// +/// For more detailed documentation and use cases, +/// please see `dependent_false` below. +template +inline constexpr bool dependent_bool_value = Value; + +/// @brief An always-false value that depends on one or more template parameters. +/// +/// This exists because `static_assert(false);` always fails, +/// even if it occurs in the `else` branch of an `if constexpr`. +/// The following example shows how to use `dependent_false` in that case. +/// +/// @code +/// template +/// void foo (T t) +/// { +/// if constexpr (std::is_integral_v) { +/// do_integer_stuff(t); +/// } +/// else if constexpr (std::is_floating_point_v) { +/// do_floating_point_stuff(t); +/// } +/// else { +/// static_assert(dependent_false, "T must be " +/// "an integral or floating-point type."); +/// } +/// } +/// @endcode +/// +/// This implements the C++ Standard Library proposal P1830R1. +/// +/// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1830r1.pdf +/// +/// That proposal is under review as of 2022/12/05. +/// The following link shows P1830's current review status. +/// +/// https://github.com/cplusplus/papers/issues/572 +/// +/// P2593R0 proposes an alternate solution to this problem, +/// that would change the C++ language itself. +/// +/// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html +/// +/// For headers in this library, however, we only consider library solutions +/// as work-arounds for future C++ features. +template +inline constexpr bool dependent_false = dependent_bool_value; + +} // end namespace cutlass::detail diff --git a/server/punica_kernels/include/cutlass/cutlass/detail/helper_macros.hpp b/server/punica_kernels/include/cutlass/cutlass/detail/helper_macros.hpp new file mode 100644 index 00000000..d51f8f19 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/detail/helper_macros.hpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Helper macros for the CUTLASS library +*/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +#ifdef CUTLASS_NAMESPACE +#define concat_tok(a, b) a ## b +#define mkcutlassnamespace(pre, ns) concat_tok(pre, ns) +#define cutlass mkcutlassnamespace(cutlass_, CUTLASS_NAMESPACE) +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#elif defined(__CUDACC_RTC__) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#else +#define CUTLASS_HOST_DEVICE inline +#define CUTLASS_DEVICE inline +#endif + +#define CUTLASS_HOST __host__ +#define CUTLASS_GLOBAL __global__ static + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) +{ } + +#if defined(__GNUC__) + #define CUTLASS_UNUSED(expr) __CUTLASS_UNUSED(expr) +#else + #define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr) +#endif + +#ifdef _MSC_VER +// Provides support for alternative operators 'and', 'or', and 'not' +#include +#endif // _MSC_VER + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#if defined(__CUDA_ARCH__) + #if defined(_MSC_VER) + #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __FUNCSIG__); asm volatile ("brkpt;\n"); } + #else + #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __PRETTY_FUNCTION__); asm volatile ("brkpt;\n"); } + #endif +#else + #if defined(_MSC_VER) + #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) + #else + #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) + #endif +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + + +#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED +#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0 +#endif + + +// CUDA 10.1 introduces the mma instruction +#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) +#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CUTLASS_ASSERT(x) assert(x) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler. 
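+// Typical use (illustrative): place the directive immediately before a loop whose trip count
+// is a compile-time constant, e.g.
+//
+//   CUTLASS_PRAGMA_UNROLL
+//   for (int i = 0; i < kRank; ++i) {
+//     coord[i] *= s;
+//   }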
+#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__) + #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) + #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") + #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1") + #else + #define CUTLASS_PRAGMA_UNROLL #pragma unroll + #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1 + #endif + + #define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL + +#else + + #define CUTLASS_PRAGMA_UNROLL + #define CUTLASS_PRAGMA_NO_UNROLL + #define CUTLASS_GEMM_LOOP + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +#define CUTLASS_THREAD_LOCAL thread_local +#else +#define CUTLASS_THREAD_LOCAL +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if (201700L <= __cplusplus) +#define CUTLASS_CONSTEXPR_IF_CXX17 constexpr +#define CUTLASS_CXX17_OR_LATER 1 +#else +#define CUTLASS_CONSTEXPR_IF_CXX17 +#define CUTLASS_CXX17_OR_LATER 0 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +}; // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/detail/layout.hpp b/server/punica_kernels/include/cutlass/cutlass/detail/layout.hpp new file mode 100644 index 00000000..16165442 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/detail/layout.hpp @@ -0,0 +1,360 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" + +#include "cute/layout.hpp" +#include "cute/util/type_traits.hpp" +#include "cute/arch/copy_sm90_tma.hpp" +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// For each cutlass::layout, provides its corresponding cute stride types, 64b by default + +template +struct TagToStrideA { + using type = L; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using type = cute::Stride, int64_t>; + using tag = layout::RowMajor; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using type = cute::Stride, int64_t, int64_t>; + using tag = layout::ColumnMajor; +}; + +template +struct TagToStrideB { + using type = L; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using type = cute::Stride, int64_t, int64_t>; + using tag = layout::RowMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using type = cute::Stride, int64_t>; + using tag = layout::ColumnMajor; +}; + +// For each cutlass::layout *, provides its corresponding cute stride types, 64b by default +// Used by pointer array and grouped gemm +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using UnderlyingType = cute::Stride, cute::Int<0>>; + using type = UnderlyingType*; + using tag = layout::RowMajor; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using UnderlyingType = cute::Stride, int64_t, cute::Int<0>>; + using type = UnderlyingType*; + using tag = layout::ColumnMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using UnderlyingType = cute::Stride, int64_t, cute::Int<0>>; + using type = UnderlyingType*; + using tag = layout::RowMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using UnderlyingType = cute::Stride, cute::Int<0>>; + using type = UnderlyingType*; + using tag = layout::ColumnMajor; +}; + +// Maps to modes [M, N, L] +template +struct TagToStrideC : TagToStrideA { }; + +// Conv: Maps to modes ((P,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, cute::Int<1>, cute::Int<0>>; +}; + +// Conv: Maps to modes (PN, C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, cute::Int<0>>; +}; + +// Conv: Maps to modes ((P,Q,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, cute::Int<1>, cute::Int<0>>; +}; + +// Conv: Maps to modes (PQN, C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, cute::Int<0>>; +}; + +// Conv: Maps to modes ((P,Q,Z,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, cute::Int<1>, cute::Int<0>>; +}; + +// Conv: Maps to modes (PQZN, C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, cute::Int<0>>; +}; 
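// Illustrative sketch (not part of the vendored header): what the TagToStride mappings above
// correspond to at runtime. For a row-major (K-major) operand A with modes (M, K, L), the
// contiguous K mode carries a compile-time stride of 1 while the M and batch modes use 64-bit
// dynamic strides. The packed-stride helper below and its extents are hypothetical.
#include <cstdint>
#include "cute/layout.hpp"

inline auto make_packed_row_major_A_stride(int64_t M, int64_t K, int64_t /*L*/) {
  // (stride_M, stride_K, stride_L) for an [M, K, L] tensor, row-major within each batch.
  return cute::make_stride(K, cute::Int<1>{}, M * K);
}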
+ +// Conv: Maps to modes (K, (C,S), _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t>, cute::Int<0>>; +}; + +// Conv: Maps to modes (K, (C,S,R), _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t, int64_t>, cute::Int<0>>; +}; + +// Conv: Maps to modes (K, (C,S,R,T), _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t, int64_t, int64_t>, cute::Int<0>>; +}; + +// Convenience aliases +template +using TagToStrideA_t = typename TagToStrideA::type; + +template +using TagToStrideB_t = typename TagToStrideB::type; + +template +using TagToStrideC_t = typename TagToStrideC::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// For 2.x compatibility APIs, provide stride->layout tag mappers + +template +constexpr bool +is_major(Stride = {}) { + // Account for stride types with and without batch mode and batch modes with static zero stride + return cute::is_constant<1, decltype(cute::front(cute::get(cute::remove_pointer_t{})))>::value; +} + +// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices +template +constexpr +auto +stride_to_layout_tag_A() { + if constexpr (is_major<0, StrideA>()) { // M major + return layout::ColumnMajor{}; + } + else { // K major + return layout::RowMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +constexpr +auto +stride_to_layout_tag_B() { + if constexpr (is_major<0, StrideB>()) { // N major + return layout::RowMajor{}; + } + else { // K major + return layout::ColumnMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +constexpr +auto +stride_to_layout_tag_C() { + if constexpr (is_major<0, StrideC>()) { // M major + return layout::ColumnMajor{}; + } + else { // N major + return layout::RowMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// Utilities to map Stride back on to their corresponding layout tags +template +struct StrideToLayoutTagA { + using type = decltype(detail::stride_to_layout_tag_A()); +}; + +template +struct StrideToLayoutTagB { + using type = decltype(detail::stride_to_layout_tag_B()); +}; + +template +struct StrideToLayoutTagC { + using type = decltype(detail::stride_to_layout_tag_C()); +}; + +// Convenience aliases +template +using StrideToLayoutTagA_t = typename StrideToLayoutTagA::type; + +template +using StrideToLayoutTagB_t = typename StrideToLayoutTagB::type; + +template +using StrideToLayoutTagC_t = typename StrideToLayoutTagC::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Inspects a tiled copy and whether its copy engine is TMA or not +template +constexpr bool is_tma_copy_engine() { + if constexpr (cute::is_void_v) { + return false; + } + else { + if constexpr ( cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + ) { + return true; + } + } + return false; +} + +template +struct RawDtype { using type = X; }; + +template +struct RawDtype> { using type = typename X::raw_type; }; + + +// Inspects a TiledCopy and returns its alignment in terms of element count +template +constexpr int +get_alignment_count_from_gmem_tiled_copy() { + + if constexpr (cute::is_void_v) { + return 1; + } + + // Account for ElementC = void kernels + else if 
constexpr (cute::is_void_v) { + return 0; + } + + else { + // For TMA tiled copies, we know the alignment has to be 128 bits + if constexpr (is_tma_copy_engine()) { + return 128 / sizeof_bits::value; + } + else { + // For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN + return GmemTiledCopy::NumValSrc; + } + } +} + +// Return the shape that is associated with stride-1 mode, or 1 if not found +template +CUTLASS_HOST_DEVICE constexpr +auto +get_contiguous_shape(Shape const & shape, Stride const & stride) { + using namespace cute; + auto idx = find_if(append(flatten(stride), _1{}), [](auto s){ return is_constant<1,decltype(s)>{}; }); + return get(append(flatten(shape), _1{})); +} + +// Check if tensor shape satisfies a given major alignment +template +CUTLASS_HOST_DEVICE constexpr +bool +check_alignment(Shape const & shape, Stride const & stride) { + return is_major<0>(stride) + ? get_contiguous_shape(cute::get<0>(shape), cute::get<0>(stride)) % Alignment == 0 + : get_contiguous_shape(cute::get<1>(shape), cute::get<1>(stride)) % Alignment == 0; +} + +// Check if tensor shape satisfies a given major alignment + +template +CUTLASS_HOST_DEVICE constexpr +size_t +alignment_for_swizzle(cute::Swizzle) { + static_assert(B >= 0 and M >= 0); + return size_t(1) << size_t(B + M + cute::abs(S)); +} + +template +CUTLASS_HOST_DEVICE constexpr +size_t +alignment_for_swizzle(Layout layout) { + return alignment_for_swizzle(cute::detail::get_swizzle_portion(layout)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/server/punica_kernels/include/cutlass/cutlass/detail/mma.hpp b/server/punica_kernels/include/cutlass/cutlass/detail/mma.hpp new file mode 100644 index 00000000..ab36e862 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/detail/mma.hpp @@ -0,0 +1,66 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
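// Illustrative sketch (not part of the vendored headers): the 2.x-compatibility helpers in the
// layout header above map a cute stride back onto its classic layout tag, so tag -> stride ->
// tag should round-trip. A minimal compile-time check using the aliases defined there:
#include "cutlass/detail/layout.hpp"

namespace layout_roundtrip_example {
  using StrideA = cutlass::detail::TagToStrideA_t<cutlass::layout::RowMajor>;
  static_assert(cute::is_same_v<cutlass::detail::StrideToLayoutTagA_t<StrideA>,
                                cutlass::layout::RowMajor>,
                "RowMajor should survive the tag -> stride -> tag round trip");
} // namespace layout_roundtrip_example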
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/mma.h" +#include "cute/layout.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct IsSparseTensorOp : cute::false_type { }; + +// The following metafunction is used to extract the OperatorClass from a cutlass 3.x kernel. +template +struct get_operator_class { + static constexpr bool is_sparse_op = IsSparseTensorOp::value; + static constexpr bool is_tensor_op = cute::size<0>(typename TiledMma::AtomShape_MNK{}) >= 8; + using type = cute::conditional_t< + is_tensor_op, + cute::conditional_t< + is_sparse_op, + cutlass::arch::OpClassSparseTensorOp, + cutlass::arch::OpClassTensorOp + >, + cutlass::arch::OpClassSimt + >; +}; + +template +using get_operator_class_t = typename get_operator_class::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/server/punica_kernels/include/cutlass/cutlass/device_kernel.h b/server/punica_kernels/include/cutlass/cutlass/device_kernel.h new file mode 100644 index 00000000..ba875a75 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/device_kernel.h @@ -0,0 +1,113 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
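// Illustrative sketch (not part of the vendored headers): get_operator_class_t from the mma
// detail header above classifies a kernel by the M extent of its TiledMMA atom. A scalar FMA
// atom (1x1x1) is classified as SIMT, while tensor-core atoms (M >= 8) map to OpClassTensorOp.
// The TiledMMA below assumes cute's default atom and thread layouts for make_tiled_mma.
#include "cute/atom/mma_atom.hpp"
#include "cutlass/detail/mma.hpp"

using SimtTiledMma = decltype(cute::make_tiled_mma(
    cute::MMA_Atom<cute::UniversalFMA<float, float, float, float>>{}));

static_assert(cute::is_same_v<cutlass::detail::get_operator_class_t<SimtTiledMma>,
                              cutlass::arch::OpClassSimt>,
              "A 1x1x1 FMA atom should be classified as OpClassSimt");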
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for generic CUTLASS kernel. +*/ + +#pragma once + +// __grid_constant__ was introduced in CUDA 11.7. +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) +# define CUTLASS_GRID_CONSTANT_SUPPORTED +#endif + +// __grid_constant__ can be enabled only on SM70+ +#if defined(CUTLASS_GRID_CONSTANT_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) +# define CUTLASS_GRID_CONSTANT_ENABLED +#endif + +#if ! defined(CUTLASS_GRID_CONSTANT) +# if defined(CUTLASS_GRID_CONSTANT_ENABLED) +# define CUTLASS_GRID_CONSTANT __grid_constant__ +# else +# define CUTLASS_GRID_CONSTANT +# endif +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +/// Generic CUTLASS kernel template. +template <typename Operator> +CUTLASS_GLOBAL +void Kernel(typename Operator::Params params) { + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + // Declare pointer to dynamic shared memory. + typename Operator::SharedStorage *shared_storage = + reinterpret_cast<typename Operator::SharedStorage *>(SharedStorageBase); + + Operator op; + + op(params, *shared_storage); +} + + +/// Generic CUTLASS kernel template. +template <typename Operator> +CUTLASS_GLOBAL +void Kernel2(typename Operator::Params params) { + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + // Declare pointer to dynamic shared memory. + typename Operator::SharedStorage *shared_storage = + reinterpret_cast<typename Operator::SharedStorage *>(SharedStorageBase); + + Operator::invoke(params, *shared_storage); +} + + +//////////////////////////////////////////////////////////////////////////////// +// +// 3.0 specific launch +// +//////////////////////////////////////////////////////////////////////////////// + +/// Generic CUTLASS kernel template. +template <typename Operator> +CUTLASS_GLOBAL +#ifdef __CUDACC__ +// Enclosing this in __CUDACC__ suppresses MSVC warnings.
+__launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) +#endif // __CUDACC__ +void device_kernel(CUTLASS_GRID_CONSTANT typename Operator::Params const params) +{ + // Dynamic shared memory base pointer + extern __shared__ char smem[]; + Operator op; + op(params, smem); +} + +//////////////////////////////////////////////////////////////////////////////// +} /// namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/builders/sm90_builder.inl b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/builders/sm90_builder.inl new file mode 100644 index 00000000..b6a41af4 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -0,0 +1,830 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
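// Illustrative sketch (not part of the vendored header): one way the device_kernel entry point
// above is typically launched. The grid shape and the `params` object are hypothetical, and the
// Operator is assumed to expose MaxThreadsPerBlock and a SharedStorage type, as the kernel
// templates in this header do.
#include <cuda_runtime.h>
#include "cutlass/device_kernel.h"

template <class Operator>
cudaError_t launch_operator(typename Operator::Params const& params, dim3 grid, cudaStream_t stream) {
  dim3 block(Operator::MaxThreadsPerBlock, 1, 1);
  int smem_bytes = int(sizeof(typename Operator::SharedStorage));
  if (smem_bytes >= (48 << 10)) {
    // Kernels needing more than 48 KB of dynamic shared memory must opt in explicitly.
    cudaError_t result = cudaFuncSetAttribute(
        cutlass::device_kernel<Operator>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_bytes);
    if (result != cudaSuccess) {
      return result;
    }
  }
  cutlass::device_kernel<Operator><<<grid, block, smem_bytes, stream>>>(params);
  return cudaGetLastError();
}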
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/atom/mma_traits_sm90.hpp" +#include "cute/atom/mma_traits_sm90_gmma.hpp" +#include "cute/atom/copy_traits_sm90.hpp" + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/gemm/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the parameterized dispatch policy for the TMA epilogue +template +constexpr auto +sm90_get_tma_dispatch_policy() { + using namespace cute; + + constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{})); + constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::sm90_is_cooperative_v ? 256 : 128); + // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation + constexpr bool ReuseSmem = (sizeof_bits_v == sizeof_bits_v) && (sizeof_bits_v > 8); + constexpr bool DelayTmaStore = is_void_v; // TMA store delay performs worse with residual loads + constexpr int StagesD = cute::min(EpiTiles, 2); + constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) + : cute::min(EpiTiles, 4); + + return Sm90TmaWarpSpecialized{}; +} + +// Returns the smem layout atom to be used for C or D matrix +template +constexpr auto +sm90_get_epilogue_smem_swizzle_layout_atom() { + using namespace cute; + + // ColMajor C/D (M-major) + if constexpr (cutlass::gemm::detail::is_major<0>(GmemStrideType{})) { + return cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::MN, Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) + >(); + } + // RowMajor C/D (N-major) + else if constexpr (cutlass::gemm::detail::is_major<1>(GmemStrideType{})) { + return cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::K , Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) + >(); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported gmem layout."); + } +} + +// Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one. +template +constexpr auto +sm90_compute_tile_shape_or_override() { + if constexpr (cute::is_same_v) { + auto epi_tile = [&] () { + if constexpr (detail::sm90_is_cooperative_v) { + auto tile_m = cute::min(_128{}, size<0>(TileShape_MNK{})); + auto tile_n = cute::min(_32{}, size<1>(TileShape_MNK{})); + return make_shape(tile_m, tile_n); + } + else if constexpr (detail::sm90_is_warp_specialized_v) { + constexpr int N_perf = sizeof_bits_v == 8 ? 
64 : 32; + auto tile_m = cute::min(_64{}, size<0>(TileShape_MNK{})); + auto tile_n = cute::min(Int{}, size<1>(TileShape_MNK{})); + return make_shape(tile_m, tile_n); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported schedule."); + } + }(); + + return cute::transform(epi_tile, seq<0,1>{}, + [] (auto epi_tiler, auto I) { + auto cta_tiler = make_layout(get(TileShape_MNK{})); + // This is a multimodal CTA tiler, transform before returning + if constexpr (depth(cta_tiler) > 0) { + // This is an implicit multimodal tiler, match profile and return + if constexpr (tuple_size_v == 1) { + return make_tile(epi_tiler); + } + // This is an explicit multimodal tiler, compose out epi tiler + else { + return composition(cta_tiler, epi_tiler); + } + } + // This is a flat CTA tiler, no need for transformation + else { + return epi_tiler; + } + }); + } + else if constexpr (cute::is_tuple::value) { + EpilogueTileType epi_tile; + constexpr int M = size<0>(shape(epi_tile)); + constexpr int N = size<1>(shape(epi_tile)); + + static_assert(!is_layout::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(M == 64 && detail::sm90_is_warp_specialized_v || + M == 128 && detail::sm90_is_cooperative_v, "Unsupported tile shape"); + static_assert(N % 16 == 0, "Unsupported tile shape"); + + return epi_tile; + } + else { + static_assert(cutlass::detail::dependent_false, "Invalid type for EpilogueTileType."); + } +} + +// Selects the largest vectorized smem store atom available +template +constexpr auto +sm90_get_smem_store_op_for_accumulator() { + using namespace cute; + + if constexpr (sizeof(ElementD) == 2 && size<0>(GmemStrideTypeD{}) == 1) { + return SM90_U16x8_STSM_T{}; + } + else if constexpr (sizeof(ElementD) == 2 && size<1>(GmemStrideTypeD{}) == 1) { + return SM90_U32x4_STSM_N{}; + } + else { + // auto-vectorizing store + return AutoVectorizingCopyWithAssumedAlignment{}; + } +} + +// Selects the largest vectorized smem load atom available +template +constexpr auto +sm90_get_smem_load_op_for_source() { + using namespace cute; + + // Reuse the logic from smem store selector + using SmemStoreOp = decltype(sm90_get_smem_store_op_for_accumulator()); + + if constexpr (cute::is_same_v) { + return SM75_U16x8_LDSM_T{}; + } + else if constexpr (cute::is_same_v) { + return SM75_U32x4_LDSM_N{}; + } + else { + // auto-vectorizing load + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } +} + +// callbacks builder with TMA aux out +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator +> +struct CallbacksBuilder< + Sm90TmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && not cute::is_subbyte_v> +> { + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + using CopyOpR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator< + GmemStrideTypeAux, typename FusionOp::ElementAux>()); + using CopyOpS2R = decltype(detail::sm90_get_smem_load_op_for_source< + GmemStrideTypeAux, typename FusionOp::ElementAux>()); + using SmemCopyOpAux = cute::conditional_t; + + using Callbacks = fusion::FusionCallbacks< + Sm90TmaWarpSpecialized, + FusionOp, 
TileShape_MNK, EpilogueTile_MN, + SmemLayoutAtomAux, SmemCopyOpAux + >; +}; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator +> +struct CallbacksBuilder< + Sm90TmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && sizeof_bits_v == 1> +> { + using Callbacks = fusion::FusionCallbacks< + Sm90TmaWarpSpecialized, + FusionOp, TileShape_MNK, EpilogueTile_MN, + Layout<_1,_0>, DefaultCopy // aux bit tensor doesn't use smem + >; +}; + +// Helper for building TMA warp-specialized collective epilogues, specialized by +// the fusion operation performed and the dispatch policy to use. +template < + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD_, + class GmemLayoutTagD, + int AlignmentD, + class FusionOpOrCallbacks, + class DispatchPolicy +> +struct Sm90TmaBuilderImpl { + // Passing void D disables destination store + smem allocation + using ElementD = cute::conditional_t, + fusion::get_element_aux_t, ElementD_>; + + // Passing void C disables source load + smem allocation + using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; + + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; + using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + + using CopyOpS2G = cute::conditional_t, + SM90_TMA_STORE_IM2COL, + SM90_TMA_STORE + >; + using CopyOpG2S = cute::conditional_t, + SM90_TMA_LOAD_IM2COL, + SM90_TMA_LOAD + >; + + // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks + // instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination + using FusionCallbacks = + typename CallbacksBuilder< + DispatchPolicy, + FusionOpOrCallbacks, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator + >::Callbacks; + + using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< + DispatchPolicy, + TileShape_MNK, + EpilogueTile_MN, + ElementC_, // Need to pass void through to expose via GemmUniversal + GmemStrideTypeC, + ElementD_, + GmemStrideTypeD, + FusionCallbacks, + CopyOpG2S, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_load_op_for_source()), + CopyOpS2G, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_store_op_for_accumulator()) + >; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Descriptor classes for defining EVT nodes +// Some of the epilogue visitor nodes require non-intuitive template arguments +// such as CopyOpS2R for AuxLoad node. Traditionaly, these are resolved by the +// builder classes. 
Here we provide a set of descriptor classes that resolve +// these template arguments from more intuitive types such as Stride, Layout + +// Get TileShape, EpilogueTile, Dispatch Policy, StagesC, and STagesD +template< + typename TileShape_MNK, + typename EpilogueTileType, + typename ElementC, + typename ElementD, + typename Schedule +> +struct EpilogueDescriptor { + using TileShape = TileShape_MNK; + using EpilogueTile = + decltype( + detail::sm90_compute_tile_shape_or_override< + ElementD, EpilogueTileType, Schedule, TileShape_MNK + >() + ); + using DispatchPolicy = + decltype( + detail::sm90_get_tma_dispatch_policy< + TileShape_MNK, EpilogueTile, + ElementC, ElementD, Schedule + >() + ); + constexpr static int StagesC = DispatchPolicy::StagesC; + constexpr static int StagesD = DispatchPolicy::StagesD; +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxLoad node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct AuxLoadDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesC; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + using SmemLayoutAtom = + decltype( + detail::sm90_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile + >() + ); + using CopyOpS2R = + decltype(detail::sm90_get_smem_load_op_for_source()); +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxStore node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct AuxStoreDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesD; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + using SmemLayoutAtom = + decltype( + detail::sm90_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile + >() + ); + using CopyOpR2S = + decltype(detail::sm90_get_smem_store_op_for_accumulator()); +}; + +template< + typename EpilogueDescriptor, + typename ElementVector +> +struct RowBroadcastDescriptor { + constexpr static int Stages = ceil_div( + EpilogueDescriptor::StagesC, + size(shape_div(take<0, 2>(typename EpilogueDescriptor::TileShape{}), typename EpilogueDescriptor::EpilogueTile{})) + ) + 1; + + using Element = ElementVector; +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////// + +// No-smem builder +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + FloatRoundStyle RoundStyle +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + fusion::LinearCombination, + cute::enable_if_t || + cute::is_same_v >> { + + // Passing void C disables source load + using ElementC = cute::conditional_t, + ElementD, ElementC_>; // prevents cute breakages + using GmemLayoutTagC = cute::conditional_t, + GmemLayoutTagD, GmemLayoutTagC_>; + static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ? 
+ thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + + static constexpr int FragmentSize = 1; + using ThreadOp = thread::LinearCombination< + ElementD, FragmentSize, ElementAccumulator, ElementCompute, + ScaleType, RoundStyle, ElementC>; + + using CollectiveOp = cute::conditional_t< + cute::is_same_v, + cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + cutlass::gemm::EpilogueDefault>>, + // Epilogue for Ptr-Array and Grouped Gemm + cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogueArray< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + Schedule>> + >; +}; + +// Tma warp-specialized builder +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD_, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class FusionOperation +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD_, + GmemLayoutTagD, + AlignmentD, + Schedule, + FusionOperation, + cute::enable_if_t || + cute::is_same_v >> { +private: + using ElementD = cute::conditional_t, + fusion::get_element_aux_t, ElementD_>; + using EpilogueTile_MN = + decltype(detail::sm90_compute_tile_shape_or_override()); + using DispatchPolicy = + decltype(detail::sm90_get_tma_dispatch_policy()); + +public: + using CollectiveOp = + typename detail::Sm90TmaBuilderImpl< + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD_, + GmemLayoutTagD, + AlignmentD, + FusionOperation, + DispatchPolicy + >::CollectiveOp; +}; + +// Auto builder +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class FusionOperation +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleAuto, + FusionOperation, + void> { +private: + static_assert(cute::is_same_v>, + "Auto schedule doesn't support fusion. 
Use one of the TmaWarpSpecialized schedules instead."); + + // Pick No-Smem epilogue as the Auto Epilogue Schedule (Auto schedules do not guarantee best performance) + // since TMA epilogues are not compatible with non-TMA non-WS mainloops + using EpilogueSchedule = NoSmemWarpSpecialized; + using _CollectiveBuilder = CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueSchedule, + FusionOperation + >; + +public: + using CollectiveOp = typename _CollectiveBuilder::CollectiveOp; +}; + +// DEPRECATED Tma warp-specialized builder for elementwise fusion +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class UnusedFusionOp +> +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]] +CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + UnusedFusionOp, + cute::enable_if_t || + cute::is_base_of_v >> { +private: + using FusionOp = + fusion::LinCombEltAct; + using ImplSchedule = + cute::conditional_t, + TmaWarpSpecialized, TmaWarpSpecializedCooperative>; + +public: + using CollectiveOp = + typename CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + ImplSchedule, + FusionOp + >::CollectiveOp; +}; + +// DEPRECATED Tma warp-specialized builder for bias + elementwise fusion +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class UnusedFusionOp +> +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltAct or fusion::LinCombPerRowBiasEltActAux instead")]] +CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + UnusedFusionOp, + cute::enable_if_t || + cute::is_base_of_v >> { +private: + using EpilogueTile_MN = decltype(detail::sm90_compute_tile_shape_or_override< + ElementD, EpilogueTileType, Schedule, TileShape_MNK>()); + // MSVC doesn't seem to be able to deduce DispatchPolicy correctly if it's + // defined as decltype of a detail::sm90_get_tma_dispatch_policy call. + // Instead, we paste in the contents of that function. A natural refactoring + // would be to create a type alias in the detail namespace. + using DispatchPolicy = Sm90TmaWarpSpecialized< + /* StagesC = */ size(shape_div(take<0, 2>(TileShape_MNK{}), EpilogueTile_MN{})), + /* StagesD = */ 2, + /* FragmentSize = */ size(EpilogueTile_MN{}) / (detail::sm90_is_cooperative_v ? 
256 : 128), + /* ReuseSmemC = */ sizeof_bits_v == sizeof_bits_v, + false + >; + + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename Schedule::ElementT, EpilogueTile_MN>()); + using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator< + GmemStrideTypeAux, typename Schedule::ElementT>()); + using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux< + GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute, + typename Schedule::ElementT, typename Schedule::ElementBias, ElementC_, ElementCompute + >; + using FusionCallbacksAux = fusion::FusionCallbacks< + DispatchPolicy, FusionOperationAux, TileShape_MNK, EpilogueTile_MN, SmemLayoutAtomAux, SmemCopyOpAux + >; + + using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct< + Schedule::template ActivationFunctor, ElementD, ElementCompute, + typename Schedule::ElementBias, ElementC_, ElementCompute + >; + using FusionCallbacksNoAux = fusion::FusionCallbacks< + DispatchPolicy, FusionOperationNoAux, TileShape_MNK, EpilogueTile_MN + >; + + using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; + + using GmemStrideTypeC = gemm::TagToStrideC_t; + using GmemStrideTypeD = gemm::TagToStrideC_t; + +public: + using CollectiveOp = cutlass::epilogue::collective::Sm90EpilogueTmaWarpSpecializedBiasElementwise< + DispatchPolicy::StagesC, + DispatchPolicy::StagesD, + DispatchPolicy::FragmentSize, + TileShape_MNK, + EpilogueTile_MN, + ElementC_, // Need to pass void through to expose via GemmUniversal + GmemStrideTypeC, + ElementD, + GmemStrideTypeD, + cute::conditional_t, + SM90_TMA_LOAD, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_load_op_for_source()), + SM90_TMA_STORE, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_store_op_for_accumulator()) + >; +}; + +// CollectiveBuilder that transposed epilogue below is used for sm90 gmma RS TT kernels +// since swapping NNN kernels input matrix and transposing its output at the same time then +// we can get TTN kernel. +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + FloatRoundStyle RoundStyle +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + cutlass::gemm::EpilogueTransposed, + fusion::LinearCombination, + void> { + // Passing void C disables source load + using ElementC = cute::conditional_t, + ElementD, ElementC_>; // prevents cute breakages + using GmemLayoutTagC = cute::conditional_t, + GmemLayoutTagD, GmemLayoutTagC_>; + static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ? 
+ thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + + static constexpr int FragmentSize = 1; + using ThreadOp = thread::LinearCombination< + ElementD, FragmentSize, ElementAccumulator, ElementCompute, + ScaleType, RoundStyle, ElementC>; + + using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + cutlass::gemm::EpilogueTransposed> + >; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/collective_builder.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/collective_builder.hpp new file mode 100644 index 00000000..10aad81d --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/collective_builder.hpp @@ -0,0 +1,117 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Used to specify epilogue subtile shape or dispatch to automatic computation of subtile shape +struct EpilogueTileAuto {}; + +// Used to let the builder pick the epilogue schedule automatically. 
+// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp +struct EpilogueScheduleAuto {}; +struct EpilogueIm2ColScheduleAuto {}; + +template < + class ArchTag, + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination, + class Enable = void +> +struct CollectiveBuilder { + static_assert(cutlass::detail::dependent_false, + "Could not build a collective epilogue for given parameters."); +}; + +// helper sub-builder for epilogue fusion callbacks (for internal use by CollectiveBuilder only) +namespace detail { + +// callbacks builder with operation tag +template< + class DispatchPolicy, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class = void +> +struct CallbacksBuilder { + using Callbacks = fusion::FusionCallbacks; +}; + +// callbacks builder with callbacks passthrough +template < + class DispatchPolicy, + class FusionCallbacks, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator +> +struct CallbacksBuilder< + DispatchPolicy, + FusionCallbacks, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + cute::enable_if_t> +> { + using Callbacks = FusionCallbacks; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "builders/sm90_builder.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/collective_epilogue.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/collective_epilogue.hpp new file mode 100644 index 00000000..d61f59f7 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/collective_epilogue.hpp @@ -0,0 +1,62 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
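// Illustrative sketch (not part of the vendored header): instantiating the epilogue
// CollectiveBuilder declared above. The tile/cluster shapes and element types are hypothetical;
// EpilogueTileAuto and EpilogueScheduleAuto let the builder choose the epilogue subtile and
// schedule, and the fusion argument defaults to a plain linear combination.
#include "cute/layout.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"

using TileShape    = cute::Shape<cute::_128, cute::_128, cute::_64>;
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    float, float,                                        // accumulator and compute types
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,    // C: element, layout tag, alignment
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,    // D: element, layout tag, alignment
    cutlass::epilogue::collective::EpilogueScheduleAuto
  >::CollectiveOp;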
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class... Args +> +class CollectiveEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "detail.hpp" +#include "default_epilogue.hpp" +#include "default_epilogue_array.hpp" +#include "epilogue_tensor_broadcast.hpp" +#include "sm70_epilogue_vectorized.hpp" +#include "sm90_epilogue_tma_warpspecialized.hpp" +#include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/default_epilogue.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/default_epilogue.hpp new file mode 100644 index 00000000..bbeeacac --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/default_epilogue.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
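// Illustrative sketch (not part of the vendored headers): the thread-level epilogue functor that
// collectives such as DefaultEpilogue (below) apply per element. With the usual linear-combination
// functor this is D = alpha * accumulator + beta * C, and the epilogue skips loading C entirely
// when the functor reports that no source is needed. A minimal scalar stand-in, not the real
// cutlass::epilogue::thread::LinearCombination:
struct LinearCombinationSketch {
  float alpha = 1.0f;
  float beta  = 0.0f;

  // DefaultEpilogue consults this to decide whether to read the source tensor C at all.
  bool is_source_needed() const { return beta != 0.0f; }

  // Two-operand form: the source value participates in the output.
  float operator()(float accumulator, float source_c) const {
    return alpha * accumulator + beta * source_c;
  }

  // One-operand form: used when is_source_needed() returns false.
  float operator()(float accumulator) const {
    return alpha * accumulator;
  }
};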
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies an element wise operation to all elements within the fragment +/// and writes them out to destination storage. +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_ +> +class DefaultEpilogue { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + using DispatchPolicy = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + using TensorStorage = SharedStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + [[maybe_unused]] 
ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + // Note: SharedStorage is unused for DefaultEpilogue + CUTLASS_HOST_DEVICE + DefaultEpilogue(Params const& params_, SharedStorage const& shared_storage = SharedStorage()) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + + // Represent the full output tensor + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = 
epilogue_op(accumulators(i)); + } + } + } + } + +private: + Params params; + ThreadEpilogueOp epilogue_op; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/default_epilogue_array.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/default_epilogue_array.hpp new file mode 100644 index 00000000..190d7699 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/default_epilogue_array.hpp @@ -0,0 +1,270 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/trace.h" + +#include "cutlass/cuda_host_adapter.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Applies an element wise operation to all elements within the fragment +// and writes them out to destination storage. 
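+//
+// Editorial note: this is the ptr-array / grouped-GEMM variant of DefaultEpilogue.
+// Its Arguments carry one C pointer and one D pointer per batch or group
+// (ElementC const** / ElementD**), and StrideC / StrideD may themselves be pointer
+// types when each group has its own stride.
+//
+// A minimal host-side sketch of filling these arguments, assuming `MyEpilogue` is a
+// concrete instantiation of DefaultEpilogueArray and `d_ptr_C` / `d_ptr_D` are
+// caller-owned device arrays of per-group pointers (hypothetical names, not part of
+// this header):
+//
+//   typename MyEpilogue::Arguments epi_args{
+//       {alpha, beta},  // ThreadEpilogueOp::Params, e.g. scalar or per-group alpha/beta
+//       d_ptr_C,        // ElementC const*, one pointer per group
+//       stride_C,       // StrideC
+//       d_ptr_D,        // ElementD*, one pointer per group
+//       stride_D};      // StrideD
+//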
+template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_ +> +class DefaultEpilogueArray { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using UnderlyingStrideC = cute::remove_pointer_t; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using UnderlyingStrideD = cute::remove_pointer_t; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::is_same_v, "Incompatible epilogue schedule."); + static_assert(rank(UnderlyingStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(UnderlyingStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC{}; + ElementD** ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const&, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + DefaultEpilogueArray(Params const& params_) + : params(params_) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. 
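+    // Beta, and therefore whether C must be read, can differ per batch/group, so this
+    // collective conservatively reports that a source may be needed. The per-group
+    // decision is made inside operator(), where the thread-level epilogue op is
+    // constructed with the group index (l_coord).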
+ return true; + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Batches are managed by using appropriate pointers to C and D matrices + const int32_t mock_L = 1; + const int32_t mock_l_coord = 0; + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. + ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); + + if (epilogue_op.is_source_needed() && params.dC == nullptr) { + // Beta value is non-zero while pointer to C is a nullptr + assert(0); + } + + UnderlyingStrideC stride_c; + UnderlyingStrideD stride_d; + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + stride_c = detail::get_epilogue_stride(params.dC[l_coord]); + } + stride_d = detail::get_epilogue_stride(params.dD[l_coord]); + } + else { + stride_c = detail::get_epilogue_stride(params.dC); + stride_d = detail::get_epilogue_stride(params.dD); + } + + // Represent the full output tensor + ElementC const* ptr_C_l = nullptr; + if (epilogue_op.is_source_needed()) { + ptr_C_l = params.ptr_C[l_coord]; + } + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gC = gC_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our 
output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i)); + } + } + } + } + +private: + Params params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/detail.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/detail.hpp new file mode 100644 index 00000000..8754a8cc --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/detail.hpp @@ -0,0 +1,284 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cute/util/type_traits.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +constexpr bool +is_m_major() { + return cutlass::gemm::detail::is_major<0,Stride>(); +} + +template +constexpr bool +is_n_major() { + return cutlass::gemm::detail::is_major<1,Stride>(); +} + +template +constexpr bool +is_im2col() { + return cute::is_same_v> + || cute::is_same_v> + || cute::is_same_v>; +} + +using cutlass::atomic_maximum; + +template +static constexpr int elements_per_access_v = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value; + +template +static constexpr bool sm90_is_cooperative_v = + cute::is_base_of_v; + +template +static constexpr bool sm90_is_warp_specialized_v = + cute::is_base_of_v; + +template +static constexpr bool is_im2col_mode = + cute::is_same_v || + cute::is_same_v || + cute::is_same_v; + +template +struct EmptyStorage { + CUTLASS_HOST_DEVICE + T* data() { return nullptr; } +}; + +template +CUTLASS_HOST_DEVICE +auto get_epilogue_stride(Stride stride){ + if constexpr (cute::is_base_of_v) { + return cute::make_stride(cute::get<1>(stride), cute::get<0>(stride), cute::get<2>(stride)); + } + else { + return stride; + } +} + +template +struct IsThreadEpilogueOpWithBias { + static constexpr bool value = false; + using type = typename ThreadEpilogueOp::ElementCompute; +}; + +template +struct IsThreadEpilogueOpWithBias > { + static constexpr bool value = true; + using type = typename ThreadEpilogueOp::ElementBias; +}; + +template +struct IsThreadEpilogueOpWithPerChannelScaling { + static constexpr bool value = false; +}; + +template +struct IsThreadEpilogueOpWithPerChannelScaling > { + static constexpr bool value = true; +}; + +template +struct IsThreadEpilogueOpWithActivation { + static constexpr bool value = false; + using type = void; +}; + +template +struct IsThreadEpilogueOpWithActivation > { + static constexpr bool value = true; + using type = typename ThreadEpilogueOp::ActivationFn; +}; + +// Wrapper class to use operator-style epilogues in sm90 TMA warp-specialized kernels +template +class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { +public: + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + using LoadPipeline = cutlass::PipelineTransactionAsync<0>; + using LoadPipelineState = cutlass::PipelineState<0>; + constexpr static uint32_t TmaTransactionBytes = 0; + + using StorePipeline = cutlass::PipelineTmaStore<0>; + using StorePipelineState = cutlass::PipelineState<0>; + + using TensorStorage = typename EpilogueOp::SharedStorage; + using PipelineStorage = typename LoadPipeline::SharedStorage; + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment([[maybe_unused]] TileShapeMNK) { + return 1; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment([[maybe_unused]] TileShapeMNK) { + return 1; + } + + CUTLASS_DEVICE + static void 
prefetch_tma_descriptors([[maybe_unused]] typename EpilogueOp::Params const&) { + } + + // ctor inheritance + using EpilogueOp::EpilogueOp; + + CUTLASS_HOST_DEVICE + Sm90TmaWarpSpecializedAdapter( + typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorStorage& shared_tensors) + : EpilogueOp(params) { } + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return false; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma + > + CUTLASS_DEVICE auto + load( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] ProblemShapeMNKL problem_shape_mnkl, + [[maybe_unused]] TileShapeMNK tile_shape_MNK, + [[maybe_unused]] TileCoordMNKL tile_coord_mnkl, + [[maybe_unused]] TiledMma tiled_mma, + [[maybe_unused]] int thread_idx, + [[maybe_unused]] TensorStorage& shared_tensors, + [[maybe_unused]] int subtile_idx=-1) + { + return load_pipe_producer_state; + } + + CUTLASS_DEVICE auto + load_tail( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state) + { + return load_pipe_producer_state; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma + > + CUTLASS_DEVICE auto + store( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_index = -1) + { + constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK); + auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl); + })); + + constexpr int BLK_N_RANK = cute::rank<1>(tile_shape_MNK); + auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl); + })); + + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); + + (*this)( + problem_shape_mnkl, + tile_shape_MNK, + tile_coord_mnkl, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + reinterpret_cast(&shared_tensors)); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_tail( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + +}; + +} // namespace detail +} // namespace collective +} // namespace epilogue +} // namespace cutlass diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp new file mode 100644 index 00000000..c870b706 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp @@ -0,0 +1,271 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Functor for performing tensor-tensor broadacasts atop existing epilogues. + + Concretely, the opeartion performed is the following: + UnaryOp( + BinaryOp1( + BinaryOp0( + Activation((alpha * A @ B) + bias), + beta * C0 + ), + beta * C1 + ) + ) + + where: + - C0 and C1 have the same extents as the output + - BinaryOp0 and BinaryOp1 perform elementwise binary operations + - UnaryOp is an elementwise operation +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Collective epilogue that applies elementwise tensor-tensor operations atop other epilogues +/// +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_, + bool PerColumnBias_ = false +> +class EpilogueTensorBroadcast { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementBias = typename ThreadEpilogueOp::ElementBias; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using ActivationFunctor = typename 
ThreadEpilogueOp::ActivationFunctor; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static constexpr int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static constexpr bool IsBinaryOp0Enabled = ThreadEpilogueOp::IsBinaryOp0Enabled; + static constexpr bool IsBinaryOp1Enabled = ThreadEpilogueOp::IsBinaryOp1Enabled; + static constexpr bool IsUnaryOpEnabled = ThreadEpilogueOp::IsUnaryOpEnabled; + + static constexpr bool PerColumnBias = PerColumnBias_; + using BiasStride = typename cute::conditional_t, Stride<_1, _0, _0>>; + + struct SharedStorage { }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementBias* ptr_Bias = nullptr; + ElementC* ptr_C0 = nullptr; + ElementC* ptr_C1 = nullptr; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + EpilogueTensorBroadcast(Params const& params_) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source0_needed() || epilogue_op.is_source1_needed(); + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + auto stride_bias = detail::get_epilogue_stride(BiasStride{}); + + // Represent the full output tensor + Tensor mC0_mnl = make_tensor(make_gmem_ptr(params.ptr_C0), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mC1_mnl = make_tensor(make_gmem_ptr(params.ptr_C1), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = 
make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_Bias), make_shape(M,N,L), stride_bias); // (m,n,l) + + Tensor gC0_mnl = local_tile(mC0_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gC1_mnl = local_tile(mC1_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gBias_mnl = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this thread block is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC0 = gC0_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gC1 = gC1_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC0 = thr_mma.partition_C(gC0); // (VEC,THR_M,THR_N) + Tensor tCgC1 = thr_mma.partition_C(gC1); // (VEC,THR_M,THR_N) + Tensor tCgBias = thr_mma.partition_C(gBias); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, + "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC0) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgC1) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + CUTE_STATIC_ASSERT_V(size(tCgBias) == size(accumulators), + "Accumulator count must have the same destination element count."); + + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + bool bias_needed = params.ptr_Bias != nullptr; + bool c0_needed = (params.ptr_C0 != nullptr) && epilogue_op.is_source0_needed(); + bool c1_needed = (params.ptr_C1 != nullptr) && epilogue_op.is_source1_needed(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + ElementBias bias = bias_needed ? tCgBias(i) : ElementBias(0); + ElementC c0 = c0_needed ? tCgC0(i) : ElementC(0); + ElementC c1 = c1_needed ? 
tCgC1(i) : ElementC(0); + + tCgD(i) = epilogue_op(accumulators(i), c0, c1, bias); + } + } + } + +private: + Params params; + ThreadEpilogueOp epilogue_op; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp new file mode 100644 index 00000000..be19944d --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -0,0 +1,357 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies an element wise operation to all elements within the fragment +/// and writes it out to destination storage. 
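+/// Accumulators are staged through shared memory (registers to smem via CopyAtomR2S,
+/// smem back to registers via TiledCopyS2R) so that the final global stores issued
+/// with CopyAtomR2G can be vectorized and coalesced.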
+/// +/// Ways to generalize this: +/// - CTA tile shape +/// - vectorization requirements (GMEM) +/// - vectoriz(able) transform() +/// +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class SmemLayout_, + class CopyAtomR2S_, + class TiledCopyS2R_, + class CopyAtomR2G_ +> +class Epilogue { +public: + // + // Type Aliases + // + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + using SmemLayout = SmemLayout_; + using CopyAtomR2S = CopyAtomR2S_; + using TiledCopyS2R = TiledCopyS2R_; + using CopyAtomR2G = CopyAtomR2G_; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage + { + cute::array_aligned> smem_epilogue; + }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + Epilogue(Params const& params_) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // synchronizing function for smem reads/writes +#if CUDA_BARRIER_ENABLED + 
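+    // With named-barrier support, only the TiledCopyS2R::TiledNumThr threads taking
+    // part in the smem exchange synchronize, using the reserved epilogue barrier; the
+    // #else branch falls back to a CTA-wide __syncthreads().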
auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; +#else + auto synchronize = [] () { __syncthreads(); }; +#endif + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Represent the full output tensor + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Construct a tensor in SMEM that we can partition for rearranging data + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sC = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) + + // Partition sC to match the accumulator partitioning + auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma); + auto tC = tiled_r2s.get_thread_slice(thread_idx); + Tensor tCaC = tC.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor tCsC = tC.partition_D(sC); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Tile gD and gC by the shape of SmemLayout first + auto tile = make_shape(size<0>(sC), size<1>(sC)); + Tensor gCt = flat_divide(gC, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gDt = flat_divide(gD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + + // Partition sC, gC, and gD for the output + auto tiled_s2r = TiledCopyS2R{}; + auto tD = tiled_s2r.get_thread_slice(thread_idx); + Tensor tDsC = tD.partition_S(sC); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tDgC = tD.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tDgD = tD.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + // Allocate intermediate registers on the dst tensors + Tensor tDrC = make_tensor(take<0,3>(shape(tDgC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tDrD = make_tensor(shape(tDrC)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + + // Repeat the D-partitioning for coordinates and predication + Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor tDcD = tD.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + CUTE_STATIC_ASSERT(size<1>(tCaC) % size<3>(tDgC) == 0); // TILE_M divides MMA_M + CUTE_STATIC_ASSERT(size<2>(tCaC) % size<4>(tDgC) == 0); // TILE_N divides MMA_N + CUTE_STATIC_ASSERT(typename TiledCopyS2R::TiledNumThr{} == size<0>(typename TiledMma::AtomLayoutC_TV{})); + +#if 0 + if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { + print("aC : "); print(accumulators.layout()); print("\n"); + print("gC : "); print(gC.layout()); print("\n"); + print("gD : "); print(gD.layout()); print("\n"); + print("sC : "); print(sC.layout()); print("\n"); + print("\n"); + print("tCsC : "); print(tCsC.layout()); print("\n"); + print("tCaC : "); print(tCaC.layout()); print("\n"); + print("\n"); + print("gDt : 
"); print(gDt.layout()); print("\n"); + print("tDsC : "); print(tDsC.layout()); print("\n"); + print("tDrC : "); print(tDrC.layout()); print("\n"); + print("\n"); + print("tDrD : "); print(tDrD.layout()); print("\n"); + print("tDgC : "); print(tDgC.layout()); print("\n"); + print("tDgD : "); print(tDgD.layout()); print("\n"); + print("\n"); + } +#endif + + // For each tiling needed for SmemLayout to cover shape(gD) + CUTLASS_PRAGMA_UNROLL + for (int step_m = 0; step_m < size<2>(cDt); ++step_m) + { + CUTLASS_PRAGMA_UNROLL + for (int step_n = 0; step_n < size<3>(cDt); ++step_n) + { + // Step 1. Copy to SMEM + CUTLASS_PRAGMA_UNROLL + for (int pipe_m = 0; pipe_m < size<1>(tCsC); ++pipe_m) { + CUTLASS_PRAGMA_UNROLL + for (int pipe_n = 0; pipe_n < size<2>(tCsC); ++pipe_n) { + int mma_m = step_m * size<1>(tCsC) + pipe_m; + int mma_n = step_n * size<2>(tCsC) + pipe_n; + + copy(tiled_r2s, tCaC(_,mma_m,mma_n), tCsC(_,pipe_m,pipe_n)); + } + } + + // Step 2. Wait for SMEM writes to complete + synchronize(); + + // Step 3. Copy from SMEM into a fragment + copy(tiled_s2r, tDsC, tDrC); + + // Step 4. Wait for SMEM reads to complete + synchronize(); + + Tensor tDgDmn = tDgD(_,_,_,step_m,step_n); + Tensor tDcDmn = tDcD(_,_,_,step_m,step_n); + + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tDgCmn = tDgC(_,_,_,step_m,step_n); + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tDgDmn); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tDgDmn); ++n) + { + // Predication + if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && + get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) + { + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tDrC); ++i) { + tDrD(i,m,n) = epilogue_op(tDrC(i,m,n), tDgCmn(i,m,n)); + } + // Step 6. Copy to GMEM + copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); + } + } + } + } + else { + // source is not needed, avoid load and lift compute + + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tDrC); ++i) { + tDrD(i) = epilogue_op(tDrC(i)); + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tDgDmn); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tDgDmn); ++n) + { + // Predication + if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && + get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) + { + // Step 6. Copy to GMEM + copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); + } + } + } + } + } + } + } + +private: + Params params; + ThreadEpilogueOp epilogue_op; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp new file mode 100644 index 00000000..e509fe18 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,813 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) + class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_ +> +class CollectiveEpilogue< + Sm90TmaWarpSpecialized, + CtaTileMNK_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm90TmaWarpSpecialized; + using CtaTileMNK = CtaTileMNK_; + using EpilogueTile = EpilogueTile_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpG2S = 
CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); + static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + +private: + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + using SmemElementD = cute::conditional_t, ElementD>; + static_assert(not cute::is_void_v, "SmemElementD is void"); + using SmemElementC = cute::conditional_t; // prevents void ref breakages + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported; + constexpr static bool DelayTmaStore = DelayTmaStore_; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + constexpr static bool is_im2col_C = cute::is_same_v; + constexpr static bool is_im2col_D = cute::is_same_v; + + using SmemLayoutC = decltype(tile_to_shape( + SmemLayoutAtomC{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using SmemLayoutD = decltype(tile_to_shape( + SmemLayoutAtomD{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC + && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + + using EmptyType = cute::tuple<>; + using SmemCStorage = cute::conditional_t, + EmptyType>; + using SmemDStorage = cute::conditional_t, + EmptyType>; + + struct TensorStorageImpl: cute::tuple { + using Base = cute::tuple; + + constexpr decltype(auto) + smem_C() { + return cute::get<0>(static_cast(*this)); + } + + constexpr decltype(auto) + smem_D() { + return cute::get<1>(static_cast(*this)); + } + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = + (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; + + // TMA pipeline for storing D + using StorePipeline = 
cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const* ptr_C; + StrideC dC; + ElementD const* ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideC{}, int32_t(0)), StrideC{}), + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{})); + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideD{}, int32_t(0)), StrideD{}), + take<0,2>(SmemLayoutD{}), + EpilogueTile{}, + _1{})); + + typename FusionCallbacks::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + // For fprop/dgrad kernel, problem shape M is multimodal which should be linearized under tiled mode + auto M_C = conditional_return(M, size(M)); + auto M_D = conditional_return(M, size(M)); + + typename Params::TMA_C tma_load_c = {}; + if constexpr (is_source_supported) { + Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M_C,N,L), args.dC)); + tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, take<0,2>(SmemLayoutC{}), EpilogueTile{}, _1{}); + } + + typename Params::TMA_D tma_store_d; + if constexpr (is_destination_supported) { + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD)); + tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, take<0,2>(SmemLayoutD{}), EpilogueTile{}, _1{}); + } + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + tma_load_c, + tma_store_d + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + if constexpr (is_destination_supported) { + constexpr int min_tma_aligned_elements_D = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideD{}); + } + + if 
constexpr (not cute::is_void_v) { + constexpr int min_tma_aligned_elements_C = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideC{}); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + return implementable; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { + // Compute number of epilogue subtiles + return size<1>(zipped_divide(make_layout(take<0,2>(tile_shape_MNK)), EpilogueTile{})); + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { + return get_load_pipe_increment(tile_shape_MNK); + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void + prefetch_tma_descriptors(Params const& epilogue_params) { + if constexpr (is_source_supported) { + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); + } + if constexpr (is_destination_supported) { + cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); + } + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_idx=-1) { + using namespace cute; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + // The tma tensor C under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). 
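+    // For example, a plain rank-3 (M,N,L) tensor is sliced with (m_coord, n_coord, l_coord),
+    // whereas an im2col TMA tensor already folds the batch dimension into its (M, N) modes
+    // and is sliced with (m_coord, n_coord) only; conditional_return below selects between
+    // the two coordinate shapes at compile time.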
+ auto coord_shape = conditional_return( + make_coord(m_coord, n_coord), + make_coord(m_coord, n_coord, l_coord)); + + // Tile residue + auto residue_mn = make_coord(M,N); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); + Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtile, get matching smem tensor + SmemElementC* ptr_sC = nullptr; + + if constexpr (is_source_supported) { + if constexpr (ReuseSmemC) { + ptr_sC = reinterpret_cast(shared_tensors.smem_D().data()); + } else { + ptr_sC = shared_tensors.smem_C().data(); + } + } + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (G2S,G2S_M,G2S_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + residue_mn, + EpilogueTile{}, + thread_idx + }; + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Acquire the lock for the first stage + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(tma_barrier, load_pipe_producer_state.count(), issue_tma_load); + + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) { + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { + continue; + } + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Loop fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Execute the TMA load for C if needed + if (issue_tma_load && is_C_load_needed) { + copy(params.tma_load_c.with(*tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + // Post-loop fusion callback entry point + pld_callbacks.end(); + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE auto + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state) { + bool issue_tma_load = cute::elect_one_sync(); + if (issue_tma_load) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + 
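+    // producer_tail blocks until every load stage committed above has been released by
+    // the consumer, so the producer warp does not retire while TMA transactions are
+    // still outstanding.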
return load_pipe_producer_state; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_idx=-1) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_rmem::value, "Accumulator must be RF resident."); + static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "TileShapeMNK must be static"); + static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto mma_tile_m = tile_size<0>(tiled_mma); + auto mma_tile_n = tile_size<1>(tiled_mma); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + + // The tma tensor D under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). + auto coord_shape = conditional_return( + make_coord(m_coord, n_coord), + make_coord(m_coord, n_coord, l_coord)); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); + Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + SmemElementC* ptr_sC = nullptr; + if constexpr (is_source_supported) { + if constexpr (ReuseSmemC) { + ptr_sC = reinterpret_cast(shared_tensors.smem_D().data()); + } else { + ptr_sC = shared_tensors.smem_C().data(); + } + } + + SmemElementD* ptr_sD = nullptr; + if constexpr (is_destination_supported) { + ptr_sD = shared_tensors.smem_D().data(); + } + + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // Get the smallest tiled copy we can use to retile the accumulators + using CopyAtomC = Copy_Atom; + TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // 
(R2S,R2S_M,R2S_N,PIPE_D) + + // Allocate D registers + Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); + Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tRS_rAcc_frg = recast>(tRS_rAcc); + Tensor tRS_rD_frg = recast>(tRS_rD); + CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tRS_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tRS_rC = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); + Tensor cD = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); + Tensor tRS_cD = thread_r2s.partition_S(flat_divide(cD, EpilogueTile{})); + auto residue_mn = make_coord(M,N); + + CUTE_STATIC_ASSERT(mma_tile_m == epi_tile_m, "EPI_TILE_M must equal MMA_TILE_M"); + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + residue_mn, + EpilogueTile{}, + tiled_copy_C_atom, + thread_idx, + cD, + tRS_cD, + tRS_rC + }; + auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. 
+ // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + int epi_m_prev = 0, epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC == StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if constexpr (is_destination_supported) { + if (issue_tma_store) { + copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + ++issued_stores; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = issued_stores > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + // Pre-loop fusion callback entry point + cst_callbacks.begin(); + + // For each output tile + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { + bool is_first_iteration = epi_m == 0 && epi_n == 0; + bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; + + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { + continue; + } + // The current tile in accumulator + int mma_m = epi_m; + int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + if constexpr (not ReuseSmemC) { + // Let producer load warp know smem buffers are consumed and empty + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + 
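+            // Illustrative note (not part of the original source): in this non-ReuseSmemC branch the
+            // C buffer can be handed back to the producer load warp as soon as its contents have been
+            // copied into registers above, so the release is issued here, once per subtile. When
+            // ReuseSmemC is enabled, C and D share the same smem allocation and the release is instead
+            // deferred to tma_store_fn, which only frees a buffer once at most StagesD-1 committed TMA
+            // stores remain in flight; hence the static_assert above that combining DelayTmaStore,
+            // ReuseSmemC and StagesC == StagesD would deadlock.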
++load_pipe_consumer_state; + } + ++load_wait_state; + } + + // Vectorized fragment loop with visitor callback entry point + int r2s_v = epi_n * size(tRS_rD_frg); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rD_frg); ++epi_v) { + tRS_rD_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration and subtile_idx == -1) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Smem reduction callback entry point using current store buffer for workspace + cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), + synchronize, epi_m, epi_n, is_last_iteration); + + // Copy tile from register to smem + if constexpr (is_destination_supported) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + constexpr bool issue_smem_store = true; // No smem store predication + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + // Post-loop fusion callback entry point + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + // reset store counter + issued_stores = 0; + + if constexpr (ReuseSmemC) { + if (fusion_callbacks.is_producer_load_needed()) { + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{})); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; + int issued_stores = 0; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp new file mode 100644 index 00000000..8eeb43c2 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp @@ -0,0 +1,158 @@ 
+/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing pipelined epilogues with bias add and elementwise activation functions. + This collective is now DEPRECATED, will be removed in the next release. Use EVT instead. 
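+
+    Editorial note (not part of the original source): this class only adapts the legacy
+    bias + elementwise Arguments onto the EVT-based CollectiveEpilogue it derives from; per the
+    deprecation messages in dispatch_policy.hpp, the equivalent fusion is expressed through the
+    TmaWarpSpecialized schedules with fusion::LinCombPerRowBiasEltAct / LinCombPerRowBiasEltActAux.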
+*/ + +#pragma once + +#include "sm90_epilogue_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + class BlockTileShape_, // (BLK_M,BLK_N,BLK_K) + class EpilogueTileShape_, // (EPI_TILE_M,EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_ +> +class Sm90EpilogueTmaWarpSpecializedBiasElementwise + : public CollectiveEpilogue< + Sm90TmaWarpSpecialized, + BlockTileShape_, + EpilogueTileShape_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_ +> { +private: + using Impl = + CollectiveEpilogue< + Sm90TmaWarpSpecialized, + BlockTileShape_, + EpilogueTileShape_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_ + >; +public: + using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise; + using ElementCompute = typename Impl::ThreadEpilogueOp::ElementCompute; + using ElementBias = typename Impl::ThreadEpilogueOp::ElementBias; + using ElementT = typename Impl::ThreadEpilogueOp::ElementAux; + + // Constructor inheritance + using Impl::Impl; + + // Host side epilogue arguments + struct [[deprecated("use Sm90TmaWarpSpecialized Arguments instead")]] + Arguments { + struct ThreadArgs { + ElementCompute alpha{1}; + ElementCompute beta{0}; + ElementCompute const *alpha_ptr{nullptr}; + ElementCompute const *beta_ptr{nullptr}; + } thread; + ElementC_ const* ptr_C{nullptr}; + StrideC_ dC{}; + ElementD_* ptr_D{nullptr}; + StrideD_ dD{}; + ElementBias const* ptr_Bias{nullptr}; + ElementT* ptr_T{nullptr}; + + CUTLASS_HOST_DEVICE + operator typename Impl::Arguments() const { + typename Impl::Arguments arguments; + arguments.thread.alpha = thread.alpha; + arguments.thread.beta = thread.beta; + arguments.thread.alpha_ptr = thread.alpha_ptr; + arguments.thread.beta_ptr = thread.beta_ptr; + if constexpr (not cute::is_void_v) { + arguments.thread.bias_ptr = ptr_Bias; + } + if constexpr (not cute::is_void_v) { + arguments.thread.aux_ptr = ptr_T; + arguments.thread.dAux = dD; + } + arguments.ptr_C = ptr_C; + arguments.dC = dC; + arguments.ptr_D = ptr_D; + arguments.dD = dD; + + return arguments; + } + }; + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/dispatch_policy.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/dispatch_policy.hpp new file mode 100644 index 00000000..409ff74d --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/dispatch_policy.hpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" + +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue { + +////////////////////////////////////////////////////////////////////////////// + +////////////////////////////////////////////////////////////////////////////// +// +// Builder Epilogue Schedules +// +////////////////////////////////////////////////////////////////////////////// + +struct NoSmemWarpSpecialized {}; +struct PtrArrayNoSmemWarpSpecialized {}; +struct TmaWarpSpecialized {}; +struct TmaWarpSpecializedCooperative {}; +// DEPRECATED schedules, will be removed in next release +struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {}; +struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {}; +template < + template class ActivationFunctor_, + thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, + FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest +> +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]] +TmaWarpSpecializedElementwise : public TmaWarpSpecializedElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + static constexpr thread::ScaleType::Kind Scale = Scale_; + static constexpr FloatRoundStyle Round = Round_; +}; + +template < + template class ActivationFunctor_, + thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, + FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest +> +struct [[deprecated("Use TmaWarpSpecializedCooperative with fusion::LinCombEltAct instead")]] +TmaWarpSpecializedCooperativeElementwise : public TmaWarpSpecializedCooperativeElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + static constexpr thread::ScaleType::Kind Scale = Scale_; + 
static constexpr FloatRoundStyle Round = Round_; +}; + +struct TmaWarpSpecializedBiasElementwiseBase : public TmaWarpSpecialized{}; +struct TmaWarpSpecializedCooperativeBiasElementwiseBase : public TmaWarpSpecializedCooperative {}; + +template < + template class ActivationFunctor_, + class ElementT_, + template class BiasOp_, + bool StoreT_, + class ElementBias_ +> +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltActAux instead")]] +TmaWarpSpecializedBiasElementwise : public TmaWarpSpecializedBiasElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + using ElementT = ElementT_; + + template + using BiasOp = BiasOp_; + + static constexpr bool StoreT = StoreT_; + using ElementBias = ElementBias_; +}; + +template < + template class ActivationFunctor_, + class ElementT_, + template class BiasOp_, + bool StoreT_, + class ElementBias_ +> +struct [[deprecated("Use TmaWarpSpecializedCooperative with fusion::LinCombPerRowBiasEltActAux instead")]] +TmaWarpSpecializedCooperativeBiasElementwise : public TmaWarpSpecializedCooperativeBiasElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + + using ElementT = ElementT_; + + template + using BiasOp = BiasOp_; + + static constexpr bool StoreT = StoreT_; + using ElementBias = ElementBias_; +}; + +////////////////////////////////////////////////////////////////////////////// +// +// Collective Dispatch Policies +// +////////////////////////////////////////////////////////////////////////////// + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct Sm90TmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; +}; + + +// DEPRECATED policies, will be removed in next release +template< + int StagesC_, + int StagesD_, + int FragmentSize_ = 2 +> +struct Sm90TmaWarpSpecializedBiasElementwise { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; +}; + +////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/callbacks.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/callbacks.hpp new file mode 100644 index 00000000..9ee37234 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/callbacks.hpp @@ -0,0 +1,89 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Dispatch interface for epilogue fusion callbacks +// For visitor fusions, this is just a convenience wrapper to provide metadata and non-nested args. +// It is also valid to just pass visitor callbacks directly to the collective, e.g. fusion::Sm90LinearCombination, +// provided the collective supports a visitor callbacks interface. This is useful for implementing custom fusions. +template < + class DispatchPolicy, // specialize on collective's dispatch policy since callbacks API will depend on collective's algorithm + class Operation, // the fusion operation being performed, e.g. fusion::LinearCombination + class CtaTile_MNK, // computed tile per CTA + class EpilogueTile_MN, // epilogue subtile size + class... Args // callbacks implementation dependent args (e.g. copy atoms, smem layouts) +> +struct FusionCallbacks { + static_assert(cutlass::detail::dependent_false, "Could not find a callbacks specialization."); +}; + +// Metadata helper to handle custom EVTs or other non-FusionCallbacks types +template +struct FusionCallbacksTraits { + using DispatchPolicy = void; + using Operation = T; + using CtaTile_MNK = void; + using EpilogueTile_MN = void; + using ElementCompute = void; +}; + +template < + class DispatchPolicy_, + class Operation_, + class CtaTile_MNK_, + class EpilogueTile_MN_, + class... 
Args +> +struct FusionCallbacksTraits< + FusionCallbacks +> { + using DispatchPolicy = DispatchPolicy_; + using Operation = Operation_; + using CtaTile_MNK = CtaTile_MNK_; + using EpilogueTile_MN = EpilogueTile_MN_; + using ElementCompute = typename Operation::ElementCompute; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/operations.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/operations.hpp new file mode 100644 index 00000000..f7a4e2d8 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/operations.hpp @@ -0,0 +1,319 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Fusion Operations +// Template args must not be implementation dependent +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct FusionOperation { + // metadata types/queries that can be overrided + using ElementOutput = void; + using ElementCompute = void; + + using ElementSource = void; + static constexpr bool IsSourceSupported = false; + + using ElementScalar = void; + static constexpr int AlignmentScalar = 0; + static constexpr bool IsScaleFactorSupported = false; + static constexpr bool IsPerRowScaleSupported = false; + using ElementBias = void; + static constexpr int AlignmentBias = 0; + static constexpr bool IsPerRowBiasSupported = false; + static constexpr bool IsDePerRowBiasSupported = false; + + using ActivationFn = void; + static constexpr bool IsEltActSupported = false; + static constexpr bool IsDeEltActSupported = false; + + using ElementAux = void; + using GmemLayoutTagAux = void; + static constexpr int AlignmentAux = 0; + static constexpr bool IsAuxOutSupported = false; + static constexpr bool IsAuxInSupported = false; + + using ElementAmax = void; + static constexpr bool IsAbsMaxSupported = false; + +}; + +// D = alpha * acc +template< + class ElementOutput_, + class ElementCompute_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledAcc : FusionOperation { + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementScalar = ElementScalar_; + static constexpr int AlignmentScalar = 1; + static constexpr auto RoundStyle = RoundStyle_; +}; + +// D = alpha * acc + beta * C +template< + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinearCombination + : ScaledAcc { + using ElementSource = ElementSource_; + static constexpr bool IsSourceSupported = true; +}; + +// D = activation(alpha * acc + beta * C) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombEltAct + : LinearCombination { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; +}; + + +// D = alpha * acc + beta * C + per-row bias +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBias + : LinearCombination { + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerRowBiasSupported = true; +}; + +// D = activation(alpha * acc + beta * C + per-row bias) +template< 
+ template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBiasEltAct + : LinCombPerRowBias { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; +}; + +// D = activation(alpha * acc + beta * C + per-row bias) +// aux = alpha * acc + beta * C + per-row bias +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBiasEltActAux + : LinCombPerRowBiasEltAct { + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + +// D = activation(per-row alpha * acc + per-row beta * C + per-row bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, // per-row alpha/beta + int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentScalar_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct PerRowLinCombPerRowBiasEltAct + : LinCombPerRowBiasEltAct { + static constexpr int AlignmentScalar = AlignmentScalar_; + static constexpr bool IsPerRowScaleSupported = true; +}; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerRowBiasEltAct + : LinCombPerRowBiasEltAct { + static constexpr bool IsScaleFactorSupported = true; +}; + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementAmax_ = ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerRowBiasEltActAmaxAux + : ScaledLinCombPerRowBiasEltAct { + using ElementAmax = ElementAmax_; + static constexpr bool IsAbsMaxSupported = true; + + using ElementAux = ElementAux_; + using GmemLayoutTagAux = 
GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + +// Z = Aux +// dY = alpha * acc + beta * C +// D = d_activation(dY, Z) +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombDeEltAct + : LinearCombination { + using ActivationFn = ActivationFn_; + static constexpr bool IsDeEltActSupported = true; + + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxInSupported = true; +}; + +// Z = Aux +// dY = alpha * acc + beta * C +// D = d_activation(dY, Z) +// dBias = sum of columns of D +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementBias_ = ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombDeEltActDePerRowBias + : LinCombDeEltAct { + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsDePerRowBiasSupported = true; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp new file mode 100644 index 00000000..b10dee87 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -0,0 +1,1402 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Fusion callbacks specializations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Sm90EVT = Sm90TreeVisitor; + +// D = alpha * acc +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledAcc, + CtaTileShapeMNK, + EpilogueTile +> : Sm90EVT, + Sm90ScalarBroadcast, + Sm90AccFetch + > { + using Impl = + Sm90EVT, + Sm90ScalarBroadcast, + Sm90AccFetch + >; + using Operation = fusion::ScaledAcc; + + struct Arguments { + // Give a name and flat ordering to the fusion callback args + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + // Conversion to the args expected by the visitor implementation + // to_underlying_arguments will implicitly call this + operator typename Impl::Arguments() const { + return + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C +template< + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinearCombination = + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + >; + 
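+// Illustrative note (not part of the original source): the FusionCallbacks specializations below
+// expose a flat Arguments struct and convert it into the nested visitor-tree arguments via the
+// conversion operator shown in each specialization. Assuming a hypothetical kernel typedef
+// Gemm::EpilogueOutputOp bound to this linear-combination fusion with float scalars, a host-side
+// caller would typically write something like:
+//
+//   typename Gemm::EpilogueOutputOp::Arguments epi_args{};
+//   epi_args.alpha = 1.0f;          // scales the accumulator
+//   epi_args.beta  = 0.5f;          // scales the C source operand
+//   epi_args.alpha_ptr = nullptr;   // optional device-resident scalars, used instead of the host values when non-null
+//   epi_args.beta_ptr  = nullptr;
+//
+// The conversion operator then packs these fields into the ternary multiply_add tree
+// (beta * C + (alpha * acc)) evaluated by Sm90LinearCombination above.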
+template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinearCombination, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C) +template< + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombEltAct = + Sm90EVT, // activation(beta * C + (alpha * acc)) + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombEltAct, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombEltAct { + + using Impl = Sm90LinCombEltAct::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombEltAct; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class 
ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerRowBias, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> { + using Impl = Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; + using Operation = fusion::LinCombPerRowBias< + ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideBias = Stride<_1,_0,int>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-row bias) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowBiasEltAct = + Sm90EVT, + Sm90LinCombPerRowBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, 
AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideBias = Stride<_1,_0,int>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-row bias) +// Aux = alpha * acc + beta * C + per-row bias) +template< + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowBiasEltActAux = + Sm90EVT, + Sm90EVT, + Sm90LinCombPerRowBias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltActAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90LinCombPerRowBiasEltActAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerRowBiasEltActAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, 
ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerRowBiasEltActAux< + GmemLayoutTagAux, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideBias = Stride<_1,_0,int>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + return + { // unary op : activation(store(beta * C + (alpha * acc + bias))) + { // unary op : store(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = per-row alpha * acc + per-row beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerRowLinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,int>, AlignmentScalar>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,int>, AlignmentScalar>, // alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias + > + >; + +// D = activation(per-row alpha * acc + per-row beta * C + per-row bias) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerRowLinCombPerRowBiasEltAct = + Sm90EVT, + Sm90PerRowLinCombPerRowBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct 
FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::PerRowLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90PerRowLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + > { + + using Impl = + Sm90PerRowLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + using Operation = + fusion::PerRowLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + + struct Arguments { + using StrideAlpha = Stride<_1,_0,int>; + using StrideBeta = Stride<_1,_0,int>; + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + StrideAlpha dAlpha = {}; + StrideBeta dBeta = {}; + + using StrideBias = Stride<_1,_0,int>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {beta_ptr, beta, dBeta}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {alpha_ptr, alpha, dAlpha}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; + +// We only apply the scaling factor if output is fp8 +template +struct ScaleOutOp { template using Op = cutlass::first; }; +template <> +struct ScaleOutOp { template using Op = cutlass::multiplies; }; +template <> +struct ScaleOutOp { template using Op = cutlass::multiplies; }; + +template +using amax = cutlass::maximum_absolute_value_reduction; // propogate nans + +}; // end namespace detail + +// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, 2>, // scale_c * beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias + > + >; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// 
else +// D = activation(Z) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltAct = + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // activation(Z) + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias + Sm90ScaledLinCombPerRowBias + >, + Sm90ScalarBroadcast // scale_d + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90ScaledLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + using StrideBias = Stride<_1,_0,int>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d + { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{scale_c, beta}, + {scale_c_ptr, beta_ptr} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{scale_a, scale_b, alpha}, + {scale_a_ptr, scale_b_ptr, alpha_ptr} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {{scale_d}, + {scale_d_ptr} + 
}, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z + +// fp8 aux specialization +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8 = + Sm90SplitTreeVisitor< + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + Sm90ScaledLinCombPerRowBias, + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90SplitTreeFetch // Z + > + >, + Sm90ScalarBroadcast // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + Sm90EVT, // store(Aux) + Sm90EVT, // Z * scale_aux + Sm90EVT, // amax_aux + Sm90SplitTreeFetch // Z + >, + Sm90ScalarBroadcast // scale_aux + > + > + >; + +// non-fp8 aux specialization +// lets us use some EVT specializations such as relu + uint1b_t aux +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8 = + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90EVT, // Aux = Z + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + Sm90ScaledLinCombPerRowBias, + > + > + >, + Sm90ScalarBroadcast // scale_d + >; + +// dispatcher +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltActAmaxAux 
= conditional_t, + Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar,AlignmentAux, AlignmentBias, RoundStyle + >, + Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > +>; + + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementAmax, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + ElementScalar scale_aux = ElementScalar(1); + ElementScalar const* scale_aux_ptr = nullptr; + + using StrideBias = Stride<_1,_0,int>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + ElementAmax* amax_D_ptr = nullptr; + ElementAmax* amax_aux_ptr = nullptr; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + // Only compute amax_d if D is fp8 + ElementAmax* amax_D_ptr_ = nullptr; + if constexpr (detail::is_fp8_v) { + amax_D_ptr_ 
= amax_D_ptr; + } + + // Aux is fp8 -> DAG arguments + if constexpr (detail::is_fp8_v) { + typename Impl::Arguments args; + // always use structured binding to unpack DAG args since it may or may not be a tuple + auto& [Z_args, aux_args, D_args] = args; + + Z_args = + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{scale_c, beta}, + {scale_c_ptr, beta_ptr} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{scale_a, scale_b, alpha}, + {scale_a_ptr, scale_b_ptr, alpha_ptr} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + + D_args = + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + {}, // leaf args : Z + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + + aux_args = + { // unary op : store(Aux) + { // binary op : Z * scale_d or Z + { // unary op : reduce(Z) + {}, // leaf args : Z + {amax_aux_ptr} // unary args : reduce + }, // end unary op + {{scale_aux}, + {scale_aux_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies + }, // end binary op + {aux_ptr, dAux} // unary args : store + }; // end unary op + + return args; + } + + // Aux is not fp8 -> Tree arguments + else { + return + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + { // unary op : store(Z) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{scale_c, beta}, + {scale_c_ptr, beta_ptr} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{scale_a, scale_b, alpha}, + {scale_a_ptr, scale_b_ptr, alpha_ptr} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias + }, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d},{scale_d_ptr}}, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpS2R, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDeEltAct = + Sm90EVT, // activation(beta * C + (alpha * acc), aux) + Sm90LinearCombination, // beta * C + (alpha * acc) + Sm90AuxLoad // aux + >; + +template < + int StagesC, + int StagesD, + int 
FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementSource, + class ElementScalar, + int AlignmentAux, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpS2R +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpS2R +> : Sm90LinCombDeEltAct< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + > { + + using Impl = + Sm90LinCombDeEltAct< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >; + using Operation = + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux const* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + return + { // binary op : activation(beta * C + (alpha * acc), aux) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, ElementAux(0), dAux}, // leaf args : aux + activation // binary args : activation + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpS2R, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDeEltActDePerRowBias = + Sm90EVT, // Identity for final conversion + Sm90EVT, AlignmentBias>, + Sm90LinCombDeEltAct + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class 
CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpS2R +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombDeEltActDePerRowBias< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpS2R +> : Sm90LinCombDeEltActDePerRowBias< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombDeEltActDePerRowBias< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombDeEltActDePerRowBias< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux const* aux_ptr = nullptr; + StrideAux dAux = {}; + + using StrideBias = Stride<_1,_0,int>; + ElementBias* dbias_ptr = nullptr; + StrideBias dDbias = {}; + + operator typename Impl::Arguments() const { + return + { // unary op : identity/convert + { // unary op : reduce(activation(beta * C + (alpha * acc), aux)) + { // binary op : activation(beta * C + (alpha * acc), aux) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, ElementAux(0), dAux}, // leaf args : aux + activation // binary args : activation + }, // end binary op + {dbias_ptr, ElementCompute(0), dDbias} // unary args : reduce + }, // end unary op + {} // unary args : identity/convert + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +template > +struct get_element_aux { + using type = void; +}; + +template +struct get_element_aux> { + using type = typename FusionOpOrCallbacks::ElementAux; +}; + +template +struct get_element_aux, cute::void_t<>> { + using type = typename get_element_aux::type; +}; + +template +struct get_element_aux, cute::void_t::Operation>> { + private: + using Operation = typename FusionCallbacks::Operation; + public: + using type = typename get_element_aux::type; +}; +} // namespace cutlass:epilogue::fusion::detail + +template +using get_element_aux_t = typename detail::get_element_aux::type; + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// + + 
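+/////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// Editor's note (illustration only, not part of the upstream header): the FusionCallbacks
+// specializations above all follow the same pattern -- a flat, user-facing Arguments struct
+// whose implicit conversion operator packs the scalars and pointers into the nested EVT
+// argument tree. A minimal host-side sketch, assuming `Callbacks` is one concrete
+// FusionCallbacks instantiation for LinCombPerRowBiasEltAct (the name `Callbacks`, `d_bias`,
+// and the surrounding GEMM setup are placeholders, not upstream API):
+//
+//   typename Callbacks::Arguments epi_args;
+//   epi_args.alpha      = 1.0f;      // scales the accumulator
+//   epi_args.beta       = 0.0f;      // scales the C source tensor
+//   epi_args.bias_ptr   = d_bias;    // per-row bias, Stride<_1,_0,int>
+//   epi_args.activation = {};        // activation hyperparameters, if any
+//
+//   // The conversion operator then yields the tree arguments for
+//   //   D = activation(beta * C + (alpha * acc + bias))
+//   // consumed by the Sm90LinCombPerRowBiasEltAct visitor:
+//   typename Callbacks::Impl::Arguments tree_args = epi_args;
+//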
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp new file mode 100644 index 00000000..330c6351 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -0,0 +1,777 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Visitor tree compute operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// N-nary Elementwise Compute Operation +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// The template argument provided for ComputeFn must be able to accept +// exactly one template parameter. In Standard C++, it's OK for +// ComputeFn to have other template parameters, as long as those have +// defaults. For example, the following struct Foo would work. +// +// template +// struct Foo { +// CUTLASS_HOST_DEVICE auto operator() (A a, B b); +// }; +// +// However, some compilers, such as Clang, require that the argument +// take _exactly_ one template parameter. This is nonstandard C++ +// behavior. One work-around for this case is to create a subclass +// with exactly one template parameter, and then use that subclass as +// the template argument. +// +// template +// struct FooHomogeneous : public Foo {}; +// +template< + template class ComputeFn, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class = void +> +struct Sm90Compute { +private: + using EmptyArguments = typename Sm90VisitorImpl<>::Arguments; + + template + struct ComputeArguments { + using type = EmptyArguments; + }; + + // partial specialization for compute fns that define an Arguments member, e.g. 
activation hyperparameters + template + struct ComputeArguments> { + using type = typename Fn::Arguments; + }; + +public: + struct SharedStorage { }; + + using Arguments = typename ComputeArguments>::type; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const&, Arguments const& args, void*) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const&, Arguments const&) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90Compute() { } + + CUTLASS_HOST_DEVICE + Sm90Compute(Params const& params, SharedStorage const& shared_storage) + : params(params) {} + + Params const params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(Params const& params) + : params(params) {} + + Params const& params; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... frg_inputs) { + return transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) { + using ElementInput = typename cute::remove_cvref_t::Element; + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + return convert_input(frg_input); + }, + [&] (auto&&... cvt_frg_inputs) { + using ComputeOutput = ComputeFn>; + using ConvertOutput = NumericArrayConverter; + ComputeOutput compute_output{}; + ConvertOutput convert_output{}; + + if constexpr (cute::is_same_v) { + return convert_output(compute_output(cvt_frg_inputs...)); + } + else { + return convert_output(compute_output(cvt_frg_inputs..., params)); + } + } + ); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks(params); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Performance Optimized Specializations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// beta * C + Z +template < + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class InputScaleOp, // beta + class ElementSource, // C + class InputAddOp // Z +> +struct Sm90TreeVisitor< + Sm90Compute().is_zero())>>, + InputScaleOp, + Sm90SrcFetch, + InputAddOp +> : Sm90VisitorImpl< + InputScaleOp, + Sm90SrcFetch, + InputAddOp, + Sm90Compute + > +{ + using Impl = + Sm90VisitorImpl< + InputScaleOp, + Sm90SrcFetch, + InputAddOp, + Sm90Compute + >; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor( + Params const& params, + SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + auto const& added_op = get<2>(Impl::ops); + return is_C_load_needed() || added_op.is_producer_load_needed(); + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + auto const& scale_op = get<0>(Impl::ops); + auto const& src_op = get<1>(Impl::ops); + auto const& added_op = get<2>(Impl::ops); + return (not scale_op.is_zero() && src_op.is_C_load_needed()) || added_op.is_C_load_needed(); + } + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(bool is_C_load_needed, CallbacksImpl&& impl) + : is_C_load_needed(is_C_load_needed), CallbacksImpl(cute::forward(impl)) { } + + bool is_C_load_needed; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_added = get<2>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + + using ElementZ = typename decltype(frg_added)::Element; + using ConvertZ = NumericArrayConverter; + using ConvertI = NumericArrayConverter; + ConvertZ convert_Z{}; + ConvertI convert_I{}; + + Array frg_I = convert_Z(frg_added); + + if (is_C_load_needed) { + Array frg_scalar = get<0>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + Array frg_source = get<1>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + + using ElementX = typename decltype(frg_scalar)::Element; + using ElementY = typename decltype(frg_source)::Element; + using ConvertX = NumericArrayConverter; + using ConvertY = NumericArrayConverter; + using ComputeI = multiply_add>; + ConvertX convert_X{}; + ConvertY convert_Y{}; + ComputeI compute_I{}; + + frg_I = compute_I(convert_X(frg_scalar), convert_Y(frg_source), frg_I); + } + + return convert_I(frg_I); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto callbacks_tuple = Impl::template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks( + is_C_load_needed(), std::move(callbacks_tuple)); + } +}; + +// ReLU with aux bit tensor dReLU/dZ +// Aux(i) = Z(i) >= 0 ? 
1 : 0 +namespace detail { +// Placeholder node so we can retain standard EVT structure +template +struct Sm90ReLUAuxStore : Sm90VisitorImpl<> { + struct SharedStorage {}; + + struct Arguments { + cutlass::uint1b_t* ptr_aux = nullptr; + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90ReLUAuxStore() { } + + CUTLASS_HOST_DEVICE + Sm90ReLUAuxStore(Params const& params, SharedStorage const& shared_storage) { } +}; +} // namespace detail + +// Specialization on the generic compute+aux EVT +template < + // Compute node + template class Activation, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + // Aux node + int Stages, + class EpilogueTile, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpR2S, + int Alignment, + bool EnableNullptr, + // Input node + class InputOp +> +struct Sm90TreeVisitor< + Sm90Compute, cutlass::epilogue::thread::ReLu> || + cute::is_same_v, cutlass::epilogue::thread::Clamp> >>, + Sm90TreeVisitor< + Sm90AuxStore< + Stages, + EpilogueTile, + cutlass::uint1b_t, + RoundStyle, + StrideMNL, + SmemLayoutAtom, + CopyOpR2S, + Alignment, + EnableNullptr + >, + InputOp + > +> : Sm90VisitorImpl< + Sm90VisitorImpl< + InputOp, + detail::Sm90ReLUAuxStore + >, + Sm90Compute + > +{ + using Impl = + Sm90VisitorImpl< + Sm90VisitorImpl< + InputOp, + detail::Sm90ReLUAuxStore + >, + Sm90Compute + >; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor(Params const& params_, SharedStorage const& shared_storage) + : params(params_), Impl(params_, shared_storage) {} + + Params const& params; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rAux, + GTensor&& tC_gAux, + CTensor tC_cAux, + ResidueMN residue_mn, + Params const& params, + CallbacksImpl&& impl) + : tC_rAux(cute::forward(tC_rAux)), + tC_gAux(cute::forward(tC_gAux)), + tC_cAux(tC_cAux), + residue_mn(residue_mn), + params(params), + CallbacksImpl(cute::forward(impl)) {} + + RTensor tC_rAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ResidueMN residue_mn; + Params const& params; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + // Unpack callbacks + params + auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple; + auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple; + auto const& [params_input_aux, params_compute] = params; + auto const& [params_input, params_aux] = params_input_aux; + + // Visit the input node + Array frg_input = callbacks_input.visit(frg_acc, epi_v, epi_m, epi_n); + + // Compute activation + aux + using ElementInput = typename decltype(frg_input)::Element; + using ConvertInput = NumericArrayConverter; + using ConvertAux = PackPredicates; 
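+      // Editor's note (illustration only): the loop below applies the activation
+      // elementwise and records, per element, whether the value was left unchanged
+      // (frg_compute[i] == pre_relu). That boolean is the dReLU/dZ bit; e.g. for an
+      // 8-wide fragment
+      //   Z       = { 1.5, -0.2, 0.0, 3.0, -7.0, 2.0, -0.1, 4.0 }
+      //   ReLU(Z) = { 1.5,  0.0, 0.0, 3.0,  0.0, 2.0,  0.0, 4.0 }
+      //   mask    = {   1,    0,   1,   1,    0,   1,    0,   1 }
+      // and ConvertAux (PackPredicates) packs the mask bits into the uint1b_t aux
+      // fragment that end() later stores to global memory.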
+ using ComputeOutput = Activation; + using ConvertOutput = NumericArrayConverter; + ConvertInput convert_input{}; + ComputeOutput relu{}; + ConvertAux convert_aux{}; + ConvertOutput convert_output{}; + + Array frg_compute = convert_input(frg_input); + bool frg_aux[FragmentSize]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + ElementCompute pre_relu = frg_compute[i]; + if constexpr (cute::is_same_v, cutlass::epilogue::thread::Clamp>) { + frg_compute[i] = relu(frg_compute[i], params_compute); + } + else { + frg_compute[i] = relu(frg_compute[i]); + } + frg_aux[i] = frg_compute[i] == pre_relu; + } + + static_assert(FragmentSize % 8 == 0, "Predicate vector must be byte-aligned"); + Tensor tC_rAux_frg = recast(coalesce(tC_rAux(_,_,_,epi_m,epi_n))); // (EPI_V) + tC_rAux_frg(epi_v) = convert_aux(frg_aux); + + return convert_output(frg_compute); + } + + CUTLASS_DEVICE void + end() { + // Unpack callbacks + params + auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple; + auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple; + auto const& [params_input_aux, params_compute] = params; + auto const& [params_input, params_aux] = params_input_aux; + + // Visit the input node + callbacks_input.end(); + + // Nullptr is no-op + if constexpr (EnableNullptr) { + if (params_aux.ptr_aux == nullptr) { + return; + } + } + + // Copy vectorizes into byte-aligned stores + constexpr int V = cute::min(Alignment, decltype(max_common_vector(tC_rAux, tC_gAux))::value); + if constexpr (V > 0 && V % 8 == 0) { + using VecType = uint_bit_t; + Tensor tC_rAux_vec = recast(tC_rAux); + Tensor tC_gAux_vec = recast(tC_gAux); + Tensor tC_cAux_vec = tC_cAux.compose(make_layout(Int{}, Int{})); // only works if vector is logically sequential + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux_vec(coords...), residue_mn); }; + copy_if(FunctionPredTensor(predicate_fn), tC_rAux_vec, tC_gAux_vec); + } + // sub-byte vectorization, must serialize threads + else { + // Assumes no inter-warp sharing of bytes (most copy layouts should satisfy this) + int lane_idx = canonical_lane_idx(); + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(coords...), residue_mn); }; + CUTLASS_PRAGMA_NO_UNROLL + for (int i = 0; i < NumThreadsPerWarp; ++i) { + if (lane_idx == i) { + copy_if(FunctionPredTensor(predicate_fn), tC_rAux, tC_gAux); + } + __syncwarp(); + } + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + // Unpack params + auto const& [params_input_aux, params_compute] = params; + auto const& [params_input, params_aux] = params_input_aux; + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator(params_aux.ptr_aux)); + Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params_aux.dAux)); // (M,N,L) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gAux, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tC_rAux = make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + auto callbacks_impl = Impl::template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks( + cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params, cute::move(callbacks_impl)); + } +}; + +// Aux load for uint1b_t +template < + int Stages, + class EpilogueTile, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpS2R, + int Alignment, + bool EnableNullptr +> +struct Sm90AuxLoad< + Stages, + EpilogueTile, + cutlass::uint1b_t, + StrideMNL, + SmemLayoutAtom, + CopyOpS2R, + Alignment, + EnableNullptr +> { + static_assert(Alignment % 128 == 0, "sub-16B alignment not supported yet"); + + struct SharedStorage {}; + + struct Arguments { + cutlass::uint1b_t const* ptr_aux = nullptr; + cutlass::uint1b_t null_default = cutlass::uint1b_t(0); + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad() { } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad(Params const& params, SharedStorage const&) + : params(params) { } + + Params const params; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& tC_rAux_, GTensor&& tC_gAux_, ResidueMN residue_mn_, Params const& params_) + : tC_rAux(cute::forward(tC_rAux_)), + tC_gAux(cute::forward(tC_gAux_)), + residue_mn(residue_mn_), + params(params_) {} + + RTensor tC_rAux; // (CPY,CPY_M,CPY_N,{EPI_M,EPI_N}) + GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ResidueMN residue_mn; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if constexpr (decltype(cute::rank(tC_rAux))::value == 5) { + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + return; + } + } + + if (elem_less(repeat_like(residue_mn, _0{}), residue_mn)) { // (partially) in-bounds CTA tile + copy_aligned(tC_gAux, tC_rAux); + } + } + } + + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool 
is_producer_load_needed) { + if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + return; + } + } + + if (elem_less(repeat_like(residue_mn, _0{}), residue_mn)) { + copy_aligned(tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); + } + } + } + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + using ElementRegister = typename remove_cvref_t::value_type; + if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { + return recast>(coalesce(tC_rAux))(epi_v); + } + else { + return recast>(coalesce(tC_rAux(_,_,_,epi_m,epi_n)))(epi_v); + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator(params.ptr_aux)); + Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gAux, args.epi_tile, args.tiled_copy, args.thread_idx); + + // If byte-unaligned vectorization, store in registers as uint32_t to reduce redundant pack+unpack instruction sequences + constexpr int V = decltype(max_common_vector(tC_gAux.layout(), make_layout(tC_gAux.shape())))::value; + Tensor tC_rAux = [&] () { + if constexpr (V % 8 != 0) { + return make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) + } else { + return make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + } + }(); + + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + fill(tC_rAux, params.null_default); + } + } + + return ConsumerStoreCallbacks( + cute::move(tC_rAux), cute::move(tC_gAux), args.residue_mn, params); + } +}; + +// dReLU specialization +template< + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle +> +struct Sm90Compute< + cutlass::epilogue::thread::dReLU, + ElementOutput, + ElementCompute, + RoundStyle +> : Sm90VisitorImpl<> { + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input, + Array const& frg_aux) { + using ConvertInput = NumericArrayConverter; + using ComputeOutput = cutlass::epilogue::thread::dReLU>; + using ConvertOutput = NumericArrayConverter; + ConvertInput convert_input{}; + ComputeOutput compute_output{}; + ConvertOutput convert_output{}; + + return convert_output(compute_output(convert_input(frg_input), frg_aux)); // don't convert frg_aux for dReLU + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp new file mode 100644 index 00000000..1ea663f6 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -0,0 +1,893 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Visitor tree load operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Fetch Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// returns accumulator +struct Sm90AccFetch : Sm90VisitorImpl<> { + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + return frg_acc; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks{}; + } +}; + +// Split tree visitor fetches intermediate results from temporary accumulators +using Sm90SplitTreeFetch = Sm90AccFetch; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// returns C +template +struct Sm90SrcFetch : Sm90VisitorImpl<> { + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return is_C_load_needed(); + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return not is_void_v; + } + + CUTLASS_DEVICE bool + is_zero() const { + return is_void_v; + } + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(SrcTensor const& tCrC) + : tCrC(tCrC) {} + + SrcTensor const& tCrC; // (CPY,CPY_M,CPY_N) + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + return recast>(tCrC)(epi_v); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + // register type may differ from logical type so we can't assert matching types here + return ConsumerStoreCallbacks(args.tCrC); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Load Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class EpilogueTile, + class Element, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpS2R, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +struct Sm90AuxLoad { + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + + constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); + // Find the max contiguous layout usable by TMA (if EpilogueTile is a non-compact tiler) + using SmemShapeTma = decltype(make_shape( + max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), + max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{}))))); + using SmemLayoutTma = decltype(tile_to_shape( + SmemLayoutAtom{}, SmemShapeTma{}, + cute::conditional_t, Step<_1,_2>>{} )); + using SmemLayout = decltype(tile_to_shape( + SmemLayoutTma{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using CopyOpG2S = + SM90_TMA_LOAD + ; + + struct SharedStorage { + alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) + array_aligned smem_aux; + }; + + struct Arguments { + Element const* ptr_aux = nullptr; + Element null_default = Element(0); + StrideMNL dAux = {}; + }; + + struct Params { + using TMA_Aux = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideMNL{}, int32_t(0)), append<3>(StrideMNL{}, _0{})), + take<0,2>(SmemLayoutTma{}))); + TMA_Aux tma_load_aux; + Element null_default = Element(0); + bool use_default = false; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto M_AUX = + size(M) + ; + Tensor tensor_aux = make_tensor(make_gmem_ptr(args.ptr_aux), make_layout(make_shape(M_AUX,N,L), append<3>(args.dAux, _0{}))); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(CopyOpG2S{}, tensor_aux, take<0,2>(SmemLayoutTma{})); + + bool use_default = false; + if constexpr (EnableNullptr) { + use_default = args.ptr_aux == nullptr; + } + + return Params{tma_load_aux, args.null_default, use_default}; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad() { } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms), + smem_aux(const_cast(shared_storage.smem_aux.data())) { } + + Params const* 
params_ptr; + Element* smem_aux; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return true; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (params_ptr->use_default && params_ptr->null_default == Element(0)); + } + + template + struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { + CUTLASS_DEVICE + ProducerLoadCallbacks(GTensor&& bGS_gAux, STensor&& bGS_sAux, Params const* params_ptr) + : bGS_gAux(cute::forward(bGS_gAux)), + bGS_sAux(cute::forward(bGS_sAux)), + params_ptr(params_ptr) {} + + GTensor bGS_gAux; // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + STensor bGS_sAux; // (TMA,TMA_M,TMA_N,PIPE) + Params const* params_ptr; + + CUTLASS_DEVICE void + step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { + if constexpr (EnableNullptr) { + if (params_ptr->use_default) { + return; + } + } + + if (issue_tma_load) { + // Increment the expected transaction bytes of the current stage's mbarrier by the subtile's byte-size + constexpr uint32_t copy_bytes = size(take<0,2>(SmemLayout{})) * sizeof_bits_v / 8; + cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); + // Issue the TMA load + constexpr uint16_t mcast_mask = 0; + int load_pipe_index = load_iteration % Stages; + copy(params_ptr->tma_load_aux.with(*full_mbarrier_ptr, mcast_mask), + bGS_gAux(_,_,_,epi_m,epi_n), bGS_sAux(_,_,_,load_pipe_index)); + } + } + }; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + auto coord_shape = + make_coord(m, n, l) + ; + Tensor mAux_mn = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mAux = coalesce(mAux_mn, take<0,2>(args.tile_shape_mnk)); + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), coord_shape); // (CTA_M,CTA_N) + + Tensor gAux_epi = flat_divide(gAux, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sAux_epi = make_tensor(make_smem_ptr(smem_aux), SmemLayout{}); // (EPI_TILE_M,EPI_TILE_N,PIPE) + + ThrCopy thrblk_g2s = params_ptr->tma_load_aux.get_slice(_0{}); + Tensor bGS_gAux = thrblk_g2s.partition_S(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + Tensor bGS_sAux = thrblk_g2s.partition_D(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) + + return ProducerLoadCallbacks( + cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr); + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& tC_rAux, TiledS2R tiled_s2r, STensorS2R&& tSR_sAux, Params const* params_ptr) + : tC_rAux(cute::forward(tC_rAux)), + tiled_s2r(tiled_s2r), + tSR_sAux(cute::forward(tSR_sAux)), + params_ptr(params_ptr) { } + + TiledS2R tiled_s2r; + RTensor tC_rAux; // (CPY,CPY_M,CPY_N) + STensorS2R tSR_sAux; // (S2R,S2R_M,S2R_N,PIPE) + Params const* params_ptr; + + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + if constexpr (EnableNullptr) { + if (params_ptr->use_default) { + fill(tC_rAux, params_ptr->null_default); + return; + } + } + + using RLayoutS2R = decltype(cute::layout(TiledS2R{}.get_slice(0).retile_S(RTensor{}))); + Tensor tSR_rAux = make_tensor(tC_rAux.data(), RLayoutS2R{}); // (S2R,S2R_M,S2R_N) + + int load_pipe_index = load_iteration % Stages; + copy(tiled_s2r, tSR_sAux(_,_,_,load_pipe_index), tSR_rAux); + } + + template + 
CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) + + return tC_rAux_frg(epi_v); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + + Tensor mAux_mn = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mAux = coalesce(mAux_mn, take<0,2>(args.tile_shape_mnk)); + Tensor tC_gAux = sm90_partition_for_epilogue(mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) + + auto tiled_s2r = conditional_return( + make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) + ); + Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) + auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE) + + return ConsumerStoreCallbacks( + cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Broadcast Load Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Scalar broadcast +// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors +template< + class Element, + class StrideMNL = Stride<_0,_0,_0>, + int BroadcastCount = 1, + template class ReductionFn = multiplies +> +struct Sm90ScalarBroadcast { + static_assert( + (cute::is_same_v>) || // scalar broadcast, e.g. alpha + (cute::is_same_v>) || // batched scalar broadcast, e.g. 
per-batch alpha + (cute::is_same_v>)); + + struct SharedStorage { }; + + struct Arguments { + Element scalars[BroadcastCount] = {}; + Element const* scalar_ptrs[BroadcastCount] = {}; + StrideMNL dScalar = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter *cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + // This must be called after update_scalar is called + CUTLASS_DEVICE bool + is_zero() const { + return scalar == Element(0); + } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { + // Get the scalar for non-batched broadcast + if (get<2>(params_ptr->dScalar) == 0) { + update_scalar(); + } + } + + Element scalar; + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + // Get the scalar for batched broadcast + if (get<2>(params_ptr->dScalar) != 0) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + } + + return EmptyProducerLoadCallbacks{}; + } + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(Element scalar) + : scalar(scalar) {} + + Element scalar; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_scalar; + frg_scalar.fill(scalar); + + return frg_scalar; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + // Get the scalar for batched broadcast + if (get<2>(params_ptr->dScalar) != 0) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + } + + return ConsumerStoreCallbacks(scalar); + } + +private: + CUTLASS_DEVICE void + update_scalar(int l_coord = 0) { + int l_offset = l_coord * size<2>(params_ptr->dScalar); + + if (params_ptr->scalar_ptrs[0] != nullptr) { + scalar = params_ptr->scalar_ptrs[0][l_offset]; + } else { + // batch stride is ignored for nullptr fallback + scalar = params_ptr->scalars[0]; + } + + // Do reduction over multiple broadcasts if necessary + ReductionFn reduction_fn; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < BroadcastCount; ++i) { + if (params_ptr->scalar_ptrs[i] != nullptr) { + scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][l_offset]); + } else { + // batch stride is ignored for nullptr fallback + scalar = reduction_fn(scalar, params_ptr->scalars[i]); + } + } + } + + template + CUTLASS_DEVICE void + update_scalar(cute::tuple) { + // Only support multiple L-modes with fully-broadcast scalar + static_assert(cute::is_same_v>); + scalar = params_ptr->scalars[0]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +constexpr int +compute_row_broadcast_stages() { + return ceil_div(StagesC, size<1>(zipped_divide(make_layout(take<0,2>(CtaTileShapeMNK{})), EpilogueTile{}))) + 1; +} + +} + +// Row vector broadcast +template< + // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least + // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +struct Sm90RowBroadcast { + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // row vector broadcast, e.g. 
per-col alpha/bias + (cute::is_same_v>)); // batched row vector broadcast + + // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem + struct SharedStorage { + alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_row; + }; + + struct Arguments { + Element const* ptr_row = nullptr; + Element null_default = Element(0); + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90RowBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params), + smem_row(const_cast(shared_storage.smem_row.data())) { } + + Params params; + Element* smem_row; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return true; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (params.ptr_row == nullptr && params.null_default == Element(0)); + } + + template + struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { + CUTLASS_DEVICE + ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params) + : gRow(cute::forward(gRow)), + sRow(cute::forward(sRow)), + params(params) {} + + GTensor gRow; // (CTA_M,CTA_N) + STensor sRow; // (CTA_M,CTA_N,PIPE) + Params const& params; + + CUTLASS_DEVICE void + begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { + if constexpr (EnableNullptr) { + if (params.ptr_row == nullptr) { + return; + } + } + + if (issue_tma_load) { + // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size + constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v / 8; + cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); + // Issue the TMA bulk copy + auto bulk_copy = Copy_Atom{}.with(*full_mbarrier_ptr); + // Filter so we don't issue redundant copies over stride-0 modes + int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; + copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index))); + } + } + }; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + + constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; + return ProducerLoadCallbacks( + cute::move(gRow), cute::move(sRow), params); + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + 
ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) + : tCrRow(cute::forward(tCrRow)), + tCsRow(cute::forward(tCsRow)), + params(params) {} + + RTensor tCrRow; // (CPY,CPY_M,CPY_N) + STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + Params const& params; + + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + if constexpr (EnableNullptr) { + if (params.ptr_row == nullptr) { + fill(tCrRow, params.null_default); + return; + } + } + + if (epi_m == 0) { // Assumes M-major subtile loop + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; + copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tCrRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor tCsRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N) + + constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; + return ConsumerStoreCallbacks( + cute::move(tCrRow), cute::move(tCsRow), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +struct Sm90ColBroadcast { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. 
batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + struct Arguments { + Element const* ptr_col = nullptr; + Element null_default = Element(0); + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (params.ptr_col == nullptr && params.null_default == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ColBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params) + : tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + + CUTLASS_DEVICE void + begin() { + if constexpr (EnableNullptr) { + if (params.ptr_col == nullptr) { + fill(tCrCol, params.null_default); + return; + } + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_aligned(filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks( + cute::move(tCgCol), cute::move(tCrCol), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Batch matrix broadcast +// Only need to redefine this if we can multicast across cluster L +template < + int Stages, + class EpilogueTile, + class Element, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpS2R, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +using Sm90MatrixBroadcast + = Sm90AuxLoad; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp new file mode 100644 index 00000000..c8d941b6 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -0,0 +1,1405 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" +#include "sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Store Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class EpilogueTile, + class Element, + FloatRoundStyle RoundStyle, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpR2S, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90AuxStore { + using ElementAux = Element; + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + + constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); + // Find the max contiguous layout usable by TMA (if EpilogueTile is a non-compact tiler) + using SmemShapeTma = decltype(make_shape( + max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), + max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{}))))); + using SmemLayoutTma = decltype(tile_to_shape( + SmemLayoutAtom{}, SmemShapeTma{}, + cute::conditional_t, Step<_1,_2>>{} )); + using SmemLayout = decltype(tile_to_shape( + SmemLayoutTma{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + struct SharedStorage { + alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) + array_aligned smem_aux; + }; + + struct Arguments { + Element* ptr_aux = nullptr; + StrideMNL dAux = {}; + }; + + struct Params { + using TMA_Aux = decltype(make_tma_copy( + SM90_TMA_STORE{}, + make_tensor(static_cast(nullptr), repeat_like(StrideMNL{}, int32_t(0)), StrideMNL{}), + SmemLayoutTma{})); + TMA_Aux tma_store_aux; + bool is_nullptr = false; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + + bool is_nullptr = false; + if constexpr (EnableNullptr) { + is_nullptr = args.ptr_aux == nullptr; + } + + typename Params::TMA_Aux tma_store_aux; + if (not is_nullptr) { + Tensor tensor_aux = make_tensor(args.ptr_aux, make_layout(make_shape(M,N,L), args.dAux)); + tma_store_aux = make_tma_copy(SM90_TMA_STORE{}, tensor_aux, SmemLayoutTma{}); + } + + return {tma_store_aux, is_nullptr}; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return 
cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxStore() { } + + CUTLASS_HOST_DEVICE + Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms), + smem_aux(const_cast(shared_storage.smem_aux.data())) { } + + Params const* params_ptr; + Element* smem_aux; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template < + class RTensor, + class TiledR2S, + class STensorR2S, + class STensorS2G, + class GTensorS2G + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rAux, + TiledR2S tiled_r2s, + STensorR2S&& tRS_sAux, + STensorS2G&& bSG_sAux, + GTensorS2G&& bSG_gAux, + Params const* params_ptr) + : tiled_r2s(tiled_r2s), + tC_rAux(cute::forward(tC_rAux)), + tRS_sAux(cute::forward(tRS_sAux)), + bSG_sAux(cute::forward(bSG_sAux)), + bSG_gAux(cute::forward(bSG_gAux)), + params_ptr(params_ptr) {} + + TiledR2S tiled_r2s; + RTensor tC_rAux; // (CPY,CPY_M,CPY_N) + STensorR2S tRS_sAux; // (R2S,R2S_M,R2S_N,PIPE) + STensorS2G bSG_sAux; // (S2G,S2G_M,S2G_N,PIPE) + GTensorS2G bSG_gAux; // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) + tC_rAux_frg(epi_v) = convert_input(frg_input); + + return frg_input; + } + + CUTLASS_DEVICE void + postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + if constexpr (EnableNullptr) { + if (params_ptr->is_nullptr) { + return; + } + } + + using RLayoutR2S = decltype(cute::layout(TiledR2S{}.get_slice(0).retile_S(RTensor{}))); + Tensor tRS_rAux = make_tensor(tC_rAux.data(), RLayoutR2S{}); // (R2S,R2S_M,R2S_N) + + if (issue_smem_store) { + int store_pipe_index = store_iteration % Stages; + copy(tiled_r2s, tRS_rAux, tRS_sAux(_,_,_,store_pipe_index)); + } + } + + CUTLASS_DEVICE void + tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + if constexpr (EnableNullptr) { + if (params_ptr->is_nullptr) { + return; + } + } + + if (issue_tma_store) { + // Issue the TMA store + int store_pipe_index = store_iteration % Stages; + copy(params_ptr->tma_store_aux, bSG_sAux(_,_,_,store_pipe_index), bSG_gAux(_,_,_,epi_m,epi_n)); + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + Tensor mAux = params_ptr->tma_store_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gAux, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) + + Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) + Tensor gAux_epi = flat_divide(gAux, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + auto tiled_r2s = conditional_return( + make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) + ); + auto tRS_sAux = tiled_r2s.get_slice(args.thread_idx).partition_D(sAux_epi); // (R2S,R2S_M,R2S_N,PIPE) + + ThrCopy thrblk_s2g = params_ptr->tma_store_aux.get_slice(_0{}); + Tensor bSG_sAux = thrblk_s2g.partition_S(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) + Tensor bSG_gAux = thrblk_s2g.partition_D(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks( + cute::move(tC_rAux), + tiled_r2s, + cute::move(tRS_sAux), + cute::move(bSG_sAux), + cute::move(bSG_gAux), + params_ptr); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Reduction Store Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Scalar reduction +template < + template class RegReduceFn, + template class GmemReduceFn, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL = Stride<_0,_0,_0>, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90ScalarReduction { +private: + static_assert( + (cute::is_same_v>) || // scalar reduction, e.g. tensor max element + (cute::is_same_v>) || // batched scalar reduction, e.g. 
per-batch max element + (cute::is_same_v>)); + static constexpr bool IsAtomic = is_atomic>::value; + static_assert(IsAtomic, "non-atomic scalar reduction not supported yet"); + +public: + struct SharedStorage { }; + + struct Arguments { + ElementOutput* ptr_scalar = nullptr; + ElementCompute reduction_identity = ElementCompute(0); + StrideMNL dScalar = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + if constexpr (IsAtomic) { + auto [M, N, K, L] = problem_shape; + Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar); + if (args.ptr_scalar != nullptr) { + return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter); + } + } + + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90ScalarReduction() { } + + CUTLASS_HOST_DEVICE + Sm90ScalarReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params const params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + int l_coord, + CTensor tCcScalar, + ResidueMN residue_mn, + Params const& params) + : scalar(params.reduction_identity), + l_coord(l_coord), + tCcScalar(tCcScalar), + residue_mn(residue_mn), + params(params) {} + + ElementCompute scalar; + int l_coord; + CTensor tCcScalar; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ResidueMN residue_mn; + Params params; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + if constexpr (EnableNullptr) { + if (params.ptr_scalar == nullptr) { + return frg_input; + } + } + + using ConvertInput = NumericArrayConverter; + using ReduceInput = RegReduceFn; + ConvertInput convert_input{}; + ReduceInput reduce_input{}; + + Array frg_I = convert_input(frg_input); + Tensor tCcScalar_mn = tCcScalar(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if (elem_less(tCcScalar_mn(epi_v * FragmentSize + i), residue_mn)) { + scalar = reduce_input(scalar, frg_I[i]); + } + } + + return frg_input; + } + + CUTLASS_DEVICE void + end() { + if constexpr (EnableNullptr) { + if (params.ptr_scalar == nullptr) { + return; + } + } + + using ConvertI = NumericConverter; + using ReduceInput = GmemReduceFn; + + ConvertI convert_I{}; + ReduceInput reduce_input{}; + + ElementOutput* ptr_scalar = params.ptr_scalar + l_coord * get<2>(params.dScalar); + reduce_input(ptr_scalar, convert_I(scalar)); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks( + get<3>(args.tile_coord_mnkl), args.tCcD, args.residue_mn, params); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Row vector reduction +template < + template class RegReduceFn, + template class ShuffleReduceFn, + template class GmemReduceFn, + int Stages, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true, // Noop on nullptr params + // If this is false, ptr_row is assumed to point to a compact n-major (ceil_div(M,CTA_M), round_nearest(N,CTA_N), L) + // tensor of ElementCompute. It is the user's responsibility to reduce this to a (N, L) tensor of ElementOutput + bool FinalReduction = true, + // False means skip OOB predication if OOB inputs are known to be the reduction identity + bool VisitCheckOOB = true +> +struct Sm90RowReduction { +private: + static_assert(Stages == 0, "Smem usage not supported yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // row vector reduction, e.g. per-col sum over all batches + (cute::is_same_v>)); // batched row vector reduction, e.g. per-col sum per batch + static constexpr bool IsAtomic = is_atomic>::value; + static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); + +public: + struct SharedStorage { }; + + struct Arguments { + void* ptr_row = nullptr; // ElementOutput* if FinalReduction, else ElementCompute* + ElementCompute reduction_identity = 0; + StrideMNL dRow = {}; + }; + + struct Params { + void* ptr_row = nullptr; + ElementCompute reduction_identity = 0; + StrideMNL dRow = {}; + ElementCompute* reduction_buffer = nullptr; + int* tile_counters = nullptr; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + ElementCompute* reduction_buffer; + int* tile_counters = nullptr; + if constexpr (IsAtomic) { + reduction_buffer = nullptr; + } + else if constexpr (not FinalReduction) { + reduction_buffer = reinterpret_cast(args.ptr_row); + } + else { + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, sizeof(int)); + + reduction_buffer = reinterpret_cast(workspace); + tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + } + + return { + args.ptr_row, + args.reduction_identity, + args.dRow, + reduction_buffer, + tile_counters + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + if constexpr (IsAtomic || not FinalReduction) { + return 0; + } + + size_t workspace_size = 0; + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + // Increment by size of reduction buffer + workspace_size += product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); + // Align and increment by size of tile counters + workspace_size = round_nearest(workspace_size, sizeof(int)); + 
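// Descriptive note (added commentary, not part of the upstream CUTLASS source): the workspace
// is laid out as one tile_N-wide row of ElementCompute partials per CTA tile per batch,
// followed by ceil_div(N, tile_N) int tile counters that track how many CTAs have deposited
// their partials for each column strip; the round_nearest above pads the buffer so the
// counters start on an int boundary.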
workspace_size += cute::ceil_div(size<>(N), tile_N) * sizeof(int); + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + if constexpr (IsAtomic) { + auto [M, N, K, L] = problem_shape; + Layout mRow_layout = make_layout(make_shape(M,N,L), args.dRow); + if (args.ptr_row != nullptr) { + return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter); + } + return Status::kSuccess; + } + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); + + int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + size_t tile_counters_size = cute::ceil_div(size<>(N), tile_N) * sizeof(int); + return zero_workspace(tile_counters, tile_counters_size, stream); + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90RowReduction() { } + + CUTLASS_HOST_DEVICE + Sm90RowReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) + : args_tuple(cute::forward(args_tuple)), + params(params) {} + + ArgsTuple args_tuple; + Params const& params; + bool do_final_reduction = false; + + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + if constexpr (EnableNullptr) { + if (params.ptr_row == nullptr) { + return frg_input; + } + } + + auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + Tensor tCrRow_mn = tCrRow(_,_,_,epi_m,epi_n); + Tensor tCcRow_mn = tCcRow(_,_,_,epi_m,epi_n); + + using ConvertInput = NumericArrayConverter; + using ReduceInput = RegReduceFn; + ConvertInput convert_input{}; + ReduceInput reduce_input{}; + + Array frg_I = convert_input(frg_input); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if constexpr (VisitCheckOOB) { + if (elem_less(tCcRow_mn(epi_v * FragmentSize + i), residue_mn)) { + ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); + tCrRow_vmn = reduce_input(tCrRow_vmn, frg_I[i]); + } + } + else { + ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); + tCrRow_vmn = reduce_input(tCrRow_vmn, frg_I[i]); + } + } + + return frg_input; + } + + template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration) { + if (not is_last_iteration) { + return; + } + + auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + auto [m, n, k, l] = tile_coord_mnkl; + constexpr bool ReferenceSrc = decltype(ref_src)::value; 
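// Added overview (descriptive commentary, not from the upstream source): the rest of reduce()
// follows the numbered phases below --
//   1. tree reduction via __shfl_down_sync over the M mode of the lane layout, so each
//      surviving lane holds its warp's per-column partials,
//   2. depending on the configuration: an atomic read-modify-write straight to the output,
//      a direct dump of the warp result to the gmem workspace (single warp in M), or a
//      cross-warp reduction staged through the smem buffer,
//   3. when deferred (non-atomic) final reduction is enabled, an atomicAdd on
//      params.tile_counters[n]; the CTA that observes the last count sets do_final_reduction
//      and folds the gmem workspace into ptr_row inside end().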
+ if constexpr (EnableNullptr) { + if (params.ptr_row == nullptr) { + return; + } + } + + // fully OOB CTA in partially OOB cluster + if (not elem_less(cRow(_0{},_0{}), residue_mn)) { + return; + } + + // + // 1. Warp shuffle reduction + // + using FragmentShuffle = Array; + using ReduceShuffle = ShuffleReduceFn; + ReduceShuffle reduce_shuffle{}; + Tensor tCrRow_frg = recast(filter(tCrRow)); + CUTLASS_PRAGMA_UNROLL + for (int reduction_rows = size<0>(lane_layout_MN) / 2; reduction_rows > 0; reduction_rows /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int frg_idx = 0; frg_idx < size(tCrRow_frg); ++frg_idx) { + uint64_t frg_shfl = reinterpret_cast(tCrRow_frg(frg_idx)); + frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(reduction_rows, _0{})); + tCrRow_frg(frg_idx) = reduce_shuffle(tCrRow_frg(frg_idx), reinterpret_cast(frg_shfl)); + } + } + bool is_reduced_lane = get<0>(lane_mn) == 0; + + // + // 2. Atomic reduction + // + if constexpr (IsAtomic) { + // Filter so we don't issue redunant copies over stride-0 modes + Tensor tCrRow_flt = filter_zeros(tCrRow); + Tensor tCcRow_flt = make_tensor(tCcRow.data(), make_layout(tCrRow_flt.shape(), tCcRow.stride())); + + Tensor tCgRow = sm90_partition_for_epilogue(gRow_l(_,_,l), epi_tile, tiled_copy, thread_idx); + Tensor tCgRow_flt = filter_zeros(tCgRow); + // NOTE: atomic reduction is performed in the output type + using ConvertOutput = NumericConverter; + using ReduceOutput = GmemReduceFn; + ConvertOutput convert_output{}; + ReduceOutput reduce_output{}; + + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrRow_flt); ++i) { + if (elem_less(tCcRow_flt(i), residue_mn)) { + reduce_output(&tCgRow_flt(i), convert_output(tCrRow_flt(i))); + } + } + } + sync_fn(); + } + + // + // 2. One warp in M, skip threadblock smem reduction + // + else if constexpr (decltype(size<0>(warp_layout_MN))::value <= 1) { + // Dump warp reduction to gmem workspace + using ElementGmem = cute::conditional_t; + Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_ml(_,_,m,l), epi_tile, tiled_copy, thread_idx); + if (is_reduced_lane) { + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_aligned(filter(tCrRow), recast(filter(tCgBuf))); + } + sync_fn(); + } + + // + // 2. 
Multiple warps in M, do threadblock smem reduction + // + else { + Tensor sBuf = make_tensor(make_smem_ptr(raw_pointer_cast(smem_buffer.data())), sBuf_layout); + static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= + decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), + "smem reduction buffer not large enough, use a larger epilogue tile"); + + // Dump warp reduction to smem workspace + Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<0>(warp_mn)), epi_tile, tiled_copy, thread_idx); + if (is_reduced_lane) { + // Filter so we don't issue redunant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_aligned(filter(tCrRow), filter(tCsBuf)); + } + sync_fn(); + + constexpr int SmemFragSize = cute::max(size_t{1}, sizeof(uint32_t) / sizeof(ElementCompute)); + using FragmentSmem = Array; + using VectorSmem = uint_bit_t>; + using ReduceSmem = GmemReduceFn; + ReduceSmem reduce_smem{}; + + Tensor sBuf_frg = recast(filter_zeros(sBuf)); + Tensor sBuf_vec = recast(filter_zeros(sBuf)); + constexpr int FragsPerRow = decltype(size<1>(sBuf_frg))::value; + + // Do the threadblock smem reduction + CUTLASS_PRAGMA_UNROLL + for (int reduction_rows = size<0>(warp_layout_MN) / 2; reduction_rows > 1; reduction_rows /= 2) { + int FragsPerReduction = reduction_rows * FragsPerRow; + CUTLASS_PRAGMA_NO_UNROLL + for (int frg_idx = thread_idx; frg_idx < FragsPerReduction; frg_idx += size(tiled_copy)) { + FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerReduction)); + sBuf_vec(frg_idx) = reinterpret_cast(frg_smem); + } + sync_fn(); + } + + // Do final smem reduction and dump to gmem workspace + using VectorGmem = cute::conditional_t; + Tensor gBuf_vec = recast(filter(gBuf_ml(_,_,m,l))); + CUTLASS_PRAGMA_NO_UNROLL + for (int frg_idx = thread_idx; frg_idx < FragsPerRow; frg_idx += size(tiled_copy)) { + FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerRow)); + gBuf_vec(frg_idx) = reinterpret_cast(frg_smem); + } + sync_fn(); + } + + // + // 3. Increment atomic counters to signal final gmem reduction + // + if constexpr (not IsAtomic && FinalReduction) { + // Ensure gmem writes are visible to other threads before incrementing counter + __threadfence(); + sync_fn(); + // Collective thread 0 increments atomic tile counter and copies value to smem + int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); + if (thread_idx == 0) { + *prev_tile_count = atomicAdd(¶ms.tile_counters[n], 1); + } + sync_fn(); + // Broadcast tile count to other threads in CTA and determine final reduction status + do_final_reduction = *prev_tile_count == size<2>(gBuf_ml) * size<3>(gBuf_ml) - 1; + sync_fn(); + } + } + + CUTLASS_DEVICE void + end() { + // + // 4. 
Do final gmem reduction if necessary + // + if constexpr (not IsAtomic && FinalReduction) { + if (not do_final_reduction) { + return; + } + + auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + + using ReduceOutput = GmemReduceFn; + using ConvertOutput = NumericConverter; + ReduceOutput reduce_output{}; + ConvertOutput convert_output{}; + + // Reduction over batches + if (size<2>(stride(gRow_l)) == 0) { + CUTLASS_PRAGMA_NO_UNROLL + for (int n = thread_idx; n < size<1>(gBuf_ml); n += size(tiled_copy)) { + Tensor tRgBuf_ml = gBuf_ml(_0{},n,_,_); + ElementCompute output = tRgBuf_ml(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int ml = 1; ml < size(tRgBuf_ml); ++ml) { + output = reduce_output(output, tRgBuf_ml(ml)); + } + if (elem_less(cRow(_0{},n), residue_mn)) { + gRow_l(_0{},n,_0{}) = convert_output(output); + } + } + } + // No reduction over batches + else { + CUTLASS_PRAGMA_NO_UNROLL + for (int n = thread_idx; n < size<1>(gBuf_ml); n += size(tiled_copy)) { + bool do_store = elem_less(cRow(_0{},n), residue_mn); + CUTLASS_PRAGMA_NO_UNROLL + for (int l = 0; l < size<3>(gBuf_ml); ++l) { + Tensor tRgBuf_m = gBuf_ml(_0{},n,_,l); + ElementCompute output = tRgBuf_m(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int m = 1; m < size(tRgBuf_m); ++m) { + output = reduce_output(output, tRgBuf_m(m)); + } + if (do_store) { + gRow_l(_0{},n,l) = convert_output(output); + } + } + } + } + + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + Layout ref_layout_MN = [&] () { + if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } + else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + }(); // tile_mn -> tv_idx + + // Get the MN layout + coord of lanes to determine shuffle reduction iterations + using _W = Int; + Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx + Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx + Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx + Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn + int lane_idx = canonical_lane_idx(); + auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); + + // Get the MN layout + coord of warps to determine smem reduction iterations + Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx + Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx + Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx + Layout inv_warp_layout_MN = right_inverse(warp_layout_MN); // warp_idx -> warp_mn + + int warp_idx = args.thread_idx / NumThreadsPerWarp; + auto warp_mn = idx2crd(inv_warp_layout_MN(warp_idx), shape(warp_layout_MN)); + + // Partition output gmem and register tensors + auto [tile_M, tile_N, tile_K] = args.tile_shape_mnk; + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); // (M,N,L) + Tensor gRow_l = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,_)); // (CTA_M,CTA_N,L) + Tensor tCgRow = 
sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gRow_l(_,_,l), args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrRow = make_tensor_like(tCgRow); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + fill(tCrRow, params.reduction_identity); + + // Partition gmem+smem reduction buffer tensors + Layout gBuf_layout = make_layout(take<0,2>(args.tile_shape_mnk), make_stride(_0{}, _1{})); + auto block_shape = ceil_div(make_shape(M,N,L), shape(gBuf_layout)); // (M_CNT, N_CNT, L_CNT) + + // Let the M_CNT (the num of partial reduction results) become the outer mode + Layout block_layout = make_layout(block_shape, make_stride(get<1>(block_shape), _1{}, get<0>(block_shape) * get<1>(block_shape))); + Layout mBuf_layout = blocked_product(gBuf_layout, block_layout); + Tensor mBuf = make_tensor(make_gmem_ptr(params.reduction_buffer), mBuf_layout); // (ceil_M,ceil_N,L) + Tensor gBuf_ml = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(_,n,_)); // (CTA_M,CTA_N,REST_M,L) + Layout sBuf_layout = blocked_product(gBuf_layout, // (CTA_M,CTA_N,WARPS_M) + make_layout(make_shape(_1{},_1{},size<0>(warp_layout_MN)))); + + auto args_tuple = make_tuple( + bool_constant{}, cute::move(tCrRow), args.tCcD, gRow_l, args.cD, gBuf_ml, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx); + return ConsumerStoreCallbacks(cute::move(args_tuple), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Col vector reduction +template < + template class RegReduceFn, + template class ShuffleReduceFn, + template class GmemReduceFn, + int Stages, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true, // Noop on nullptr params + // If this is false, ptr_col is assumed to point to a compact m-major (round_nearest(M,CTA_M), ceil_div(N,CTA_N), L) + // tensor of ElementCompute. It is the user's responsibility to reduce this to a (M, L) tensor of ElementOutput + bool FinalReduction = true, + // False means skip OOB predication if OOB inputs are known to be the reduction identity + bool VisitCheckOOB = true +> +struct Sm90ColReduction { +private: + static_assert(Stages == 0, "Smem usage not supported yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector reduction, e.g. per-row sum over all batches + (cute::is_same_v>)); // batched col vector reduction, e.g. 
per-row sum per batch + static constexpr bool IsAtomic = is_atomic>::value; + static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); + +public: + struct SharedStorage { }; + + struct Arguments { + void* ptr_col = nullptr; // ElementOutput* if FinalReduction, else ElementCompute* + ElementCompute reduction_identity = 0; + StrideMNL dCol = {}; + }; + + struct Params { + void* ptr_col = nullptr; + ElementCompute reduction_identity = 0; + StrideMNL dCol = {}; + ElementCompute* reduction_buffer = nullptr; + int* tile_counters = nullptr; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + ElementCompute* reduction_buffer; + int* tile_counters = nullptr; + if constexpr (IsAtomic) { + reduction_buffer = nullptr; + } + else if constexpr (not FinalReduction) { + reduction_buffer = reinterpret_cast(args.ptr_col); + } + else { + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, sizeof(int)); + + reduction_buffer = reinterpret_cast(workspace); + tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + } + + return { + args.ptr_col, + args.reduction_identity, + args.dCol, + reduction_buffer, + tile_counters + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + if constexpr (IsAtomic || not FinalReduction) { + return 0; + } + + size_t workspace_size = 0; + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + + // Increment by size of reduction buffer + workspace_size += product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); + // Align and increment by size of tile counters + workspace_size = round_nearest(workspace_size, sizeof(int)); + workspace_size += cute::ceil_div(M, tile_M) * sizeof(int); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + if constexpr (IsAtomic) { + auto [M, N, K, L] = problem_shape; + Layout mCol_layout = make_layout(make_shape(M,N,L), args.dCol); + if (args.ptr_col != nullptr) { + return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter); + } + return Status::kSuccess; + } + + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, sizeof(int)); + + int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + size_t tile_counters_size = cute::ceil_div(M, tile_M) * sizeof(int); + return zero_workspace(tile_counters, tile_counters_size, stream); + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90ColReduction() { } + + CUTLASS_HOST_DEVICE + Sm90ColReduction(Params const& params, SharedStorage const& shared_storage) + : 
params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) + : args_tuple(cute::forward(args_tuple)), + params(params) {} + + ArgsTuple args_tuple; + Params const& params; + bool do_final_reduction = false; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + if constexpr (EnableNullptr) { + if (params.ptr_col == nullptr) { + return frg_input; + } + } + + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); + + using ConvertInput = NumericArrayConverter; + using ReduceInput = RegReduceFn; + ConvertInput convert_input{}; + ReduceInput reduce_input{}; + + Array frg_I = convert_input(frg_input); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if constexpr (VisitCheckOOB) { + if (elem_less(tCcCol_mn(epi_v * FragmentSize + i), residue_mn)) { + ElementCompute& tCrCol_vmn = tCrCol_mn(epi_v * FragmentSize + i); + tCrCol_vmn = reduce_input(tCrCol_vmn, frg_I[i]); + } + } + else { + if (elem_less(tCcCol_mn(epi_v * FragmentSize + i), residue_mn)) { + ElementCompute& tCrCol_vmn = tCrCol_mn(epi_v * FragmentSize + i); + tCrCol_vmn = reduce_input(tCrCol_vmn, frg_I[i]); + } + } + } + + return frg_input; + } + + template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration) { + if (not is_last_iteration) { + return; + } + + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + auto [m, n, k, l] = tile_coord_mnkl; + constexpr bool ReferenceSrc = decltype(ref_src)::value; + + // Runtime nullptr is noop + if constexpr (EnableNullptr) { + if (params.ptr_col == nullptr) { + return; + } + } + + // fully OOB CTA in partially OOB cluster + if (not elem_less(cCol(_0{},_0{}), residue_mn)) { + return; + } + + // + // 1. Warp shuffle reduction + // + using FragmentShuffle = Array; + using ReduceShuffle = ShuffleReduceFn; + ReduceShuffle reduce_shuffle{}; + Tensor tCrCol_frg = recast(filter(tCrCol)); + CUTLASS_PRAGMA_UNROLL + for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int frg_idx = 0; frg_idx < size(tCrCol_frg); ++frg_idx) { + uint64_t frg_shfl = reinterpret_cast(tCrCol_frg(frg_idx)); + frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(_0{},reduction_cols)); + tCrCol_frg(frg_idx) = reduce_shuffle(tCrCol_frg(frg_idx), reinterpret_cast(frg_shfl)); + } + } + bool is_reduced_lane = get<1>(lane_mn) == 0; + + // + // 2. 
Atomic reduction + // + if constexpr (IsAtomic) { + // Filter so we don't issue redunant copies over stride-0 modes + Tensor tCrCol_flt = filter_zeros(tCrCol); + Tensor tCcCol_flt = make_tensor(tCcCol.data(), make_layout(tCrCol_flt.shape(), tCcCol.stride())); + + Tensor tCgCol = sm90_partition_for_epilogue(gCol_l(_,_,l), epi_tile, tiled_copy, thread_idx); + Tensor tCgCol_flt = filter_zeros(tCgCol); + + // NOTE: atomic reduction is performed in the output type + using ConvertOutput = NumericConverter; + using ReduceOutput = GmemReduceFn; + ConvertOutput convert_output{}; + ReduceOutput reduce_output{}; + + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrCol_flt); ++i) { + if (elem_less(tCcCol_flt(i), residue_mn)) { + reduce_output(&tCgCol_flt(i), convert_output(tCrCol_flt(i))); + } + } + } + sync_fn(); + } + + // + // 2. One warp in N, skip threadblock smem reduction + // + else if constexpr (decltype(size<1>(warp_layout_MN))::value <= 1) { + // Dump warp reduction to gmem workspace + using ElementGmem = cute::conditional_t; + Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_nl(_,_,n,l), epi_tile, tiled_copy, thread_idx); + if (is_reduced_lane) { + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_aligned(filter(tCrCol), recast(filter(tCgBuf))); + } + sync_fn(); + } + + // + // 2. Multiple warps in N, do threadblock smem reduction + // + else { + Tensor sBuf = make_tensor(make_smem_ptr(raw_pointer_cast(smem_buffer.data())), sBuf_layout); + static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= + decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), + "smem reduction buffer not large enough, use a larger epilogue tile"); + + // Dump warp reduction to smem workspace + Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<1>(warp_mn)), epi_tile, tiled_copy, thread_idx); + if (is_reduced_lane) { + // Filter so we don't issue redunant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_aligned(filter(tCrCol), filter(tCsBuf)); + } + sync_fn(); + + constexpr int SmemFragSize = cute::max(size_t{1}, sizeof(uint32_t) / sizeof(ElementCompute)); + using FragmentSmem = Array; + using VectorSmem = uint_bit_t>; + using ReduceSmem = GmemReduceFn; + ReduceSmem reduce_smem{}; + + Tensor sBuf_frg = recast(filter_zeros(sBuf)); + Tensor sBuf_vec = recast(filter_zeros(sBuf)); + constexpr int FragsPerCol = decltype(size<0>(sBuf_frg))::value; + + // Do the threadblock smem reduction + CUTLASS_PRAGMA_UNROLL + for (int reduction_cols = size<1>(warp_layout_MN) / 2; reduction_cols > 1; reduction_cols /= 2) { + int FragsPerReduction = reduction_cols * FragsPerCol; + CUTLASS_PRAGMA_NO_UNROLL + for (int frg_idx = thread_idx; frg_idx < FragsPerReduction; frg_idx += size(tiled_copy)) { + FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerReduction)); + sBuf_vec(frg_idx) = reinterpret_cast(frg_smem); + } + sync_fn(); + } + + // Do final smem reduction and dump to gmem workspace + using VectorGmem = cute::conditional_t; + Tensor gBuf_vec = recast(filter(gBuf_nl(_,_,n,l))); + CUTLASS_PRAGMA_NO_UNROLL + for (int frg_idx = thread_idx; frg_idx < FragsPerCol; frg_idx += size(tiled_copy)) { + FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerCol)); + gBuf_vec(frg_idx) = reinterpret_cast(frg_smem); + } + sync_fn(); + 
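// Added note (descriptive commentary, not from the upstream source): steps 3 and 4 below
// mirror the row reduction above -- an atomicAdd on params.tile_counters[m] elects the CTA
// that observes the last partial, and that CTA folds the gmem workspace into ptr_col
// inside end().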
} + + // + // 3. Increment atomic counters to signal final gmem reduction + // + if constexpr (not IsAtomic && FinalReduction) { + // Ensure gmem writes are visible to other threads before incrementing counter + __threadfence(); + sync_fn(); + // Collective thread 0 increments atomic tile counter and copies value to smem + int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); + if (thread_idx == 0) { + *prev_tile_count = atomicAdd(¶ms.tile_counters[m], 1); + } + sync_fn(); + // Broadcast tile count to other threads in CTA and determine final reduction status + do_final_reduction = *prev_tile_count == size<2>(gBuf_nl) * size<3>(gBuf_nl) - 1; + sync_fn(); + } + } + + CUTLASS_DEVICE void + end() { + // + // 4. Do final gmem reduction if necessary + // + if constexpr (not IsAtomic && FinalReduction) { + if (not do_final_reduction) { + return; + } + + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + + using ReduceOutput = GmemReduceFn; + using ConvertOutput = NumericConverter; + ReduceOutput reduce_output{}; + ConvertOutput convert_output{}; + + // Reduction over batches + if (size<2>(stride(gCol_l)) == 0) { + CUTLASS_PRAGMA_NO_UNROLL + for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { + Tensor tRgBuf_nl = gBuf_nl(m,_0{},_,_); + ElementCompute output = tRgBuf_nl(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int nl = 1; nl < size(tRgBuf_nl); ++nl) { + output = reduce_output(output, tRgBuf_nl(nl)); + } + if (elem_less(cCol(m,_0{}), residue_mn)) { + gCol_l(m,_0{},_0{}) = convert_output(output); + } + } + } + // No reduction over batches + else { + CUTLASS_PRAGMA_NO_UNROLL + for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { + bool do_store = elem_less(cCol(m,_0{}), residue_mn); + CUTLASS_PRAGMA_NO_UNROLL + for (int l = 0; l < size<3>(gBuf_nl); ++l) { + Tensor tRgBuf_n = gBuf_nl(m,_0{},_,l); + ElementCompute output = tRgBuf_n(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int n = 1; n < size(tRgBuf_n); ++n) { + output = reduce_output(output, tRgBuf_n(n)); + } + if (do_store) { + gCol_l(m,_0{},l) = convert_output(output); + } + } + } + } + + } + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + Layout ref_layout_MN = [&] () { + if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } + else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + }(); // tile_mn -> tv_idx + + // Get the MN layout + coord of lanes to determine shuffle reduction iterations + using _W = Int; + Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx + Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx + Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx + Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn + int lane_idx = canonical_lane_idx(); + auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); + + // Get the MN layout + coord of warps to determine smem reduction iterations + Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx + Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx + Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx + Layout inv_warp_layout_MN = right_inverse(warp_layout_MN); // warp_idx -> warp_mn + int warp_idx = args.thread_idx / NumThreadsPerWarp; + auto warp_mn = idx2crd(inv_warp_layout_MN(warp_idx), shape(warp_layout_MN)); + + // Partition output gmem and register tensors + auto [tile_M, tile_N, tile_K] = args.tile_shape_mnk; + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); // (M,N,L) + Tensor gCol_l = local_tile(mCol, take<0,2>(args.tile_shape_mnk), make_coord(m,n,_)); // (CTA_M,CTA_N,L) + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gCol_l(_,_,l), args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + fill(tCrCol, params.reduction_identity); + + // Partition gmem+smem reduction buffer tensors + Layout gBuf_layout = make_layout(take<0,2>(args.tile_shape_mnk), make_stride(_1{}, _0{})); + Layout mBuf_layout = blocked_product(gBuf_layout, make_layout(ceil_div(make_shape(M,N,L), shape(gBuf_layout)))); + Tensor mBuf = make_tensor(make_gmem_ptr(params.reduction_buffer), mBuf_layout); // (ceil_M,ceil_N,L) + Tensor gBuf_nl = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(m,_,_)); // (CTA_M,CTA_N,REST_N,L) + Layout sBuf_layout = blocked_product(gBuf_layout,make_layout(make_shape(_1{},_1{},size<1>(warp_layout_MN)))); // (CTA_M,CTA_N,WARPS_N) + + auto args_tuple = make_tuple( + bool_constant{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx); + return ConsumerStoreCallbacks(std::move(args_tuple), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Batch matrix reduction +template < + int Stages, + class EpilogueTile, + class Element, + class StrideMNL, + class CopyOpR2S, + class SmemLayoutAtom, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90MatrixReduction; + +///////////////////////////////////////////////////////////////////////////////////////////////// + 
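The deterministic (non-atomic) path above splits the cross-CTA column reduction in two: every CTA writes its partial column to the gmem workspace, makes the writes visible with `__threadfence()`, and has thread 0 `atomicAdd` the per-tile counter; only the CTA that observes the final count runs the `end()` reduction over the workspace. A simplified, hedged sketch of that handshake for a single scalar per CTA is shown below (illustrative only: `partials`, `counter`, `out`, and `num_tiles` are invented names, `counter` and `out` are assumed zero-initialized, and each CTA is assumed to have already stored its own entry of `partials`):

__device__ void publish_and_maybe_finalize(const float* partials,  // [num_tiles] one partial per CTA
                                           float* out,             // final reduced value
                                           unsigned int* counter,  // zero-initialized tile counter
                                           int num_tiles) {
  __threadfence();                        // make this CTA's partial visible to all other CTAs
  __shared__ unsigned int prev;
  if (threadIdx.x == 0) {
    prev = atomicAdd(counter, 1u);        // returns the count before this CTA arrived
  }
  __syncthreads();
  if (prev == num_tiles - 1) {            // last CTA to arrive performs the final reduction
    float acc = 0.f;
    for (int t = threadIdx.x; t < num_tiles; t += blockDim.x) {
      acc += partials[t];
    }
    atomicAdd(out, acc);                  // combine the per-thread partial sums
  }
}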
+} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp new file mode 100644 index 00000000..1e07cc89 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -0,0 +1,1066 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree operation base implementation to enable composable fusions + for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using cute::tuple; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partitioning Helpers +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class CtaTileMN, + class EpilogueTile, + class TiledCopy +> +CUTLASS_HOST_DEVICE +constexpr auto +sm90_partition_for_epilogue( + CtaTileMN cT, // (CTA_M,CTA_N,...) 
+ EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) + TiledCopy tiled_copy, + int thread_idx) { + ThrCopy thread_copy = tiled_copy.get_thread_slice(thread_idx); + Tensor cT_epi = flat_divide(cT, epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,...) + if constexpr (ReferenceSrc) { + return thread_copy.partition_S(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) + } + else { + return thread_copy.partition_D(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) + } +} + +template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class Engine, class LayoutMNL, + class TileShapeMNK, + class TileCoordMNKL, + class EpilogueTile, + class TiledCopy +> +CUTLASS_HOST_DEVICE +constexpr auto +sm90_partition_for_epilogue( + Tensor mT, // (M,N,L) + TileShapeMNK tile_shape_mnk, // (CTA_M,CTA_N,CTA_K) + TileCoordMNKL tile_coord_mnkl, // (m,n,k,l) + EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) + TiledCopy tiled_copy, + int thread_idx) { + auto [m, n, k, l] = tile_coord_mnkl; + auto coord_shape = + make_coord(m, n, l) + ; + Tensor cT = local_tile(mT, take<0,2>(tile_shape_mnk), coord_shape); // (CTA_M,CTA_N) + Tensor tCcT = + sm90_partition_for_epilogue(cT, epi_tile, tiled_copy, thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return tCcT; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Visitor Implementation +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class ResidueMN, + class EpilogueTile +> +struct ProducerLoadArgs { + ProblemShapeMNKL problem_shape_mnkl; + TileShapeMNK tile_shape_mnk; + TileCoordMNKL tile_coord_mnkl; + ResidueMN residue_mn; + EpilogueTile epi_tile; + int thread_idx; + + CUTLASS_DEVICE + ProducerLoadArgs( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + ResidueMN residue_mn, + EpilogueTile epi_tile, + int thread_idx) + : problem_shape_mnkl(problem_shape_mnkl), + tile_shape_mnk(tile_shape_mnk), + tile_coord_mnkl(tile_coord_mnkl), + residue_mn(residue_mn), + epi_tile(epi_tile), + thread_idx(thread_idx) {} +}; + +template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class ResidueMN, + class EpilogueTile, + class TiledCopy, + class CoordTensor, + class ThrCoordTensor, + class ThrSrcTensor +> +struct ConsumerStoreArgs { + ProblemShapeMNKL problem_shape_mnkl; + TileShapeMNK tile_shape_mnk; + TileCoordMNKL tile_coord_mnkl; + ResidueMN residue_mn; + EpilogueTile epi_tile; + TiledCopy tiled_copy; + int thread_idx; + CoordTensor cD; + ThrCoordTensor tCcD; + ThrSrcTensor const& tCrC; + + CUTLASS_DEVICE + ConsumerStoreArgs( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + ResidueMN residue_mn, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + CoordTensor cD, + ThrCoordTensor tCcD, + ThrSrcTensor const& tCrC) + : problem_shape_mnkl(problem_shape_mnkl), + tile_shape_mnk(tile_shape_mnk), + tile_coord_mnkl(tile_coord_mnkl), + residue_mn(residue_mn), + epi_tile(epi_tile), + tiled_copy(tiled_copy), + thread_idx(thread_idx), + cD(cD), + tCcD(tCcD), + tCrC(tCrC) {} +}; + +template +struct Sm90VisitorImplBase { + // Shared memory allocation + using SharedStorage = tuple; + // Host side fusion arguments + using Arguments = tuple; + // Device side fusion params (Kernel-entry API) + using Params = tuple; + + 
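The Arguments/Params/SharedStorage members are simply per-op tuples, and the helpers below (`to_underlying_arguments`, `get_workspace_size`, `initialize_workspace`) walk the ops in order, handing each one a sub-workspace and advancing the pointer by that op's size rounded up to `MinWorkspaceAlignment`. A hedged sketch of just that offset arithmetic, separate from this header (the helpers `align_up` and `op_workspace_offset` are invented names):

#include <cstddef>

// Round x up to the next multiple of `alignment` (works for any positive alignment).
inline std::size_t align_up(std::size_t x, std::size_t alignment) {
  return (x + alignment - 1) / alignment * alignment;
}

// Byte offset of op i's sub-workspace when per-op workspaces are packed sequentially.
inline std::size_t op_workspace_offset(const std::size_t* sizes, int i, std::size_t alignment) {
  std::size_t offset = 0;
  for (int k = 0; k < i; ++k) {
    offset += align_up(sizes[k], alignment);
  }
  return offset;
}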
template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + uint8_t* op_workspace = reinterpret_cast(workspace); + return transform_apply(tuple{}, args, + [&] (auto&& op, auto const& op_args) { + using Op = cute::remove_cvref_t; + auto ret = Op::to_underlying_arguments(problem_shape, op_args, op_workspace); + if (op_workspace != nullptr) { + size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); + op_workspace += round_nearest(op_workspace_size, MinWorkspaceAlignment); + } + return ret; + }, + [] (auto&&... op_params) { return cute::make_tuple(op_params...); } + ); + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return transform_apply(tuple{}, args, + [&] (auto&& op, auto const& op_args) { + using Op = cute::remove_cvref_t; + size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); + return round_nearest(op_workspace_size, MinWorkspaceAlignment); + }, + [&] (auto&&... op_workspace_size) { + return (0 + ... + op_workspace_size); + } + ); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* op_workspace = reinterpret_cast(workspace); + return transform_apply(tuple{}, args, + // Initialize each operation's workspace, stopping at the first error + [&] (auto&& op, auto const& op_args) { + if (status != Status::kSuccess) { + return status; + } + + using Op = cute::remove_cvref_t; + status = Op::initialize_workspace(problem_shape, op_args, op_workspace, stream, cuda_adapter); + if (op_workspace != nullptr) { + size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); + op_workspace += round_nearest(op_workspace_size, MinWorkspaceAlignment); + } + return status; + }, + // Return the final status + [&] (auto const&...ops) { return status; } + ); + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops(transform_apply(tuple{}, params, shared_storage, + [] (auto&& op, auto const& op_params, auto&& op_storage) { + using Op = cute::remove_cvref_t; + return Op(op_params, op_storage); + }, + [] (auto&&... ops) { return cute::make_tuple(ops...); } + )) {} + + // Ops can store kernel persistent variables (e.g. descriptors, scalars, wave counters) + tuple ops; +}; + + +template +struct Sm90VisitorImpl : Sm90VisitorImplBase { + + using Impl = Sm90VisitorImplBase; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90VisitorImpl() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImpl(Params const& params, SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} + + using Impl::ops; + + // + // Queries for kernel runtime + // + + // Is a specialized warp for producer TMA loads needed + // e.g. Aux tensor loads, broadcasts using TMA bulk copy + // This condition cannot change between work tiles because it is used + // to determine whether the load warp should exit early or not + // e.g. for batched beta this must always be true regardless of current batch idx + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return cute::apply(ops, + [] (auto const&... op) { + return (false || ... 
|| op.is_producer_load_needed()); + } + ); + } + + // Is a producer TMA load specifically for C needed + // If this is true then is_producer_load_needed must also be true + // This condition can change between work tiles because it is only used + // to determine whether the TMA and smem loads for C of a given tile should happen + // e.g. for batched beta this can be false depending on current batch idx + CUTLASS_DEVICE bool + is_C_load_needed() const { + return cute::apply(ops, + [] (auto const&... op) { + return (false || ... || op.is_C_load_needed()); + } + ); + } + + // + // Producer load callbacks, called by the epilogue load warp. + // Operations usually only define this if TMA load is needed. Most operations will reuse this empy implementation + // Load callbacks are responsible for issuing corresponding mbarrier expect-tx ops for any TMA loads issued, but + // are not responsible for issuing the producer_commit barrier arrival, which is issued by the collective instead + // If this is non-empty, is_producer_load_needed must be true. + // + template + struct ProducerLoadCallbacks { + // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables + CallbacksTuple callbacks_tuple; + + // Before entry of the subtile load loop. Bulk copies usually performed here. + // Upon entry the producer_acquire of the first subtile lock has completed. + // full_mbarrier_ptr is the corresponding barrier for the subsequent producer_commit arrival + CUTLASS_DEVICE void + begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.begin(full_mbarrier_ptr, load_iteration, issue_tma_load); + } + ); + } + + // Entry of the subtile load loop. Aux loads usually performed here + // Upon entry the producer acquire of the current subtile lock has completed. + // Upon exit all TMA loads for this subtile must have been issued, with corresponding expect-tx operations + CUTLASS_DEVICE void + step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.step(full_mbarrier_ptr, epi_m, epi_n, load_iteration, issue_tma_load); + } + ); + } + + // Exit of the subtile load loop. + CUTLASS_DEVICE void + end() { + for_each(callbacks_tuple, + [] (auto& callbacks) { + callbacks.end(); + } + ); + } + }; + + // Producer load callbacks factory + // All operations must redefine this, but most can just dispatch to the base impl + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return transform_apply(ops, + [&] (auto& op) { + return op.get_producer_load_callbacks(args); + }, + [] (auto&&... callbacks) { + auto callbacks_tuple = cute::make_tuple(callbacks...); + return ProducerLoadCallbacks{callbacks_tuple}; + } + ); + } + + // + // Consumer store callbacks, called by the epilogue store warps. + // All operations must redefine this, with optional inheritance from this empty implementation. + // + template + struct ConsumerStoreCallbacks { + // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables + CallbacksTuple callbacks_tuple; + + // Before entry of subtile store loop. Gmem broadcasts usually performed here. + CUTLASS_DEVICE void + begin() { + for_each(callbacks_tuple, + [] (auto& callbacks) { + callbacks.begin(); + } + ); + } + + // Start of subtile store iteration. Smem broadcasts usually performed here. 
+ // Upon entry, all producer loads for this subtile are completed and visible. + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.previsit(epi_m, epi_n, load_iteration, is_producer_load_needed); + } + ); + } + + // Perform the fused elementwise computation + template + CUTLASS_DEVICE auto // returns an Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... frg_inputs) // depends on the N-naryness of the op + = delete; // Must be implemented for each operation + + // After visit call. Smem reductions usually performed here + // reduction_buffer is an arbitrary smem tensor that can be used for workspace + // It is each nodes reponsibility to assert that this buffer is sufficiently sized + // and to ensure that this buffer is no longer needed upon callback exit + // i.e. results are synchronized and no longer in the reduction buffer + template + CUTLASS_DEVICE void + reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.reduce(reduction_buffer, sync_fn, epi_m, epi_n, is_last_iteration); + } + ); + } + + // After reduce call, before smem async fence. Smem stores usually performed here. + // Upon exit, all smem stores for TMA must have been issued + CUTLASS_DEVICE void + postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.postreduce(epi_m, epi_n, store_iteration, issue_smem_store); + } + ); + } + + // After smem async fence, before TMA store commit. Aux stores usually performed here + // Upon exit, all TMA stores for this subtile must have been issued + // Because of the TMA store delay optimization, this entry point must ONLY be used for TMA stores + // other gmem stores can be placed in the reduce or postreduce entry points + CUTLASS_DEVICE void + tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.tma_store(epi_m, epi_n, store_iteration, issue_tma_store); + } + ); + } + + // Exit of subtile store loop. Gmem reductions usually performed here. + CUTLASS_DEVICE void + end() { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.end(); + } + ); + } + }; + + // Consumer store callbacks factory + // All operations must redefine this + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return transform_apply(ops, + [&] (auto& op) { + return op.template get_consumer_store_callbacks(args); + }, + [] (auto&&... 
callbacks) { + auto callbacks_tuple = cute::make_tuple(callbacks...); + return ConsumerStoreCallbacks{callbacks_tuple}; + } + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Convenience aliases +using EmptyProducerLoadCallbacks = Sm90VisitorImpl<>::ProducerLoadCallbacks>; +using EmptyConsumerStoreCallbacks = Sm90VisitorImpl<>::ConsumerStoreCallbacks>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tree visitor +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Sm90TreeVisitor : Sm90VisitorImpl { + + using Impl = Sm90VisitorImpl; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor( + Params const& params, + SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) {} + + using CallbacksImpl::callbacks_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + constexpr int Rm1 = sizeof...(ChildOps); + return cute::detail::tapply(callbacks_tuple, + [&] (auto& child_callbacks) { + return child_callbacks.visit(frg_acc, epi_v, epi_m, epi_n); // child ops must be nullary (e.g. loads, trees) + }, + [&] (auto&&... frg_inputs) { + return get(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); + }, + make_seq{} // restrict the transform to R-1 child ops, apply is for node op + ); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto callbacks_tuple = Sm90VisitorImpl:: + template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(std::move(callbacks_tuple)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// DAG visitors +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Most DAG fusions can be represented as a set of output trees with a common input tree +// The common input is first evaluated, then the result is passed as the acc fragment to the output trees +template +struct Sm90SplitTreeVisitor : Sm90VisitorImpl { + + using Sm90VisitorImpl::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) {} + + using CallbacksImpl::callbacks_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_input = get<0>(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + + constexpr int Rm2 = sizeof...(AuxOutTrees); + cute::for_each(make_seq{}, // restrict the sequence to aux out trees + [&] (auto I) { + get(callbacks_tuple).visit(frg_input, epi_v, epi_m, epi_n); + } + ); + + return get(callbacks_tuple).visit(frg_input, epi_v, epi_m, epi_n); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto callbacks_tuple = Sm90VisitorImpl:: + template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(std::move(callbacks_tuple)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // deducing the output type for all the nodes is tricky so we just convert them all to a common type + // if multiple compute types are needed then split into multiple subgraphs grouped by type + class ElementCompute, + class EdgeTuple, // tuple of int_sequence, each sequence is the children indices (indexed by topological order) for each node + class... Ops // in topological order, last op is the output. EdgeTuple must match this order +> +struct Sm90TopologicalVisitor : Sm90VisitorImpl { + static_assert(is_static_v); + static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops)); + static_assert(sizeof...(Ops) > 1); + + using Sm90VisitorImpl::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) {} + + using CallbacksImpl::callbacks_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + constexpr int Rm1 = sizeof...(Ops) - 1; + auto frg_compute_tuple = cute::repeat(Array{}); + + return cute::detail::tapply(EdgeTuple{}, callbacks_tuple, frg_compute_tuple, + // Visit the first R-1 ops in topological order + [&] (auto&& edge_seq, auto& callbacks, auto& frg_compute) { + frg_compute = cute::detail::apply(frg_compute_tuple, + // Compute the current op with children inputs + [&] (auto const&... 
frg_inputs) { + auto frg_output = callbacks.visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); + using ElementOutput = typename decltype(frg_output)::Element; + using ConvertOutput = NumericArrayConverter; + ConvertOutput convert_output{}; + + return convert_output(frg_output); + }, + // Get inputs in the sequence given by the children indices of the current op + edge_seq + ); + return frg_compute; // unused + }, + // Visit the last op + [&] (auto const&...ops) { + return cute::detail::apply(frg_compute_tuple, + // Compute the last op with children inputs + [&] (auto const&... frg_inputs) { + return get(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); + }, + // Get inputs in the sequence given by the children indices of the last op + get(EdgeTuple{}) + ); + }, + // Transform to visit R-1 ops, apply to visit last op + make_seq{} + ); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto callbacks_tuple = Sm90VisitorImpl:: + template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(std::move(callbacks_tuple)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Base specializations so we can have standard layout params and simple aggregate initializers +namespace detail { + +template +struct Sm90VisitorImplBase { + + // Retain tuple for SharedStorage because empty structs have 1B alignment + // tuples use multiple inheritance, avoids this problem + using SharedStorage = tuple< + typename Op0::SharedStorage + >; + + struct Arguments { + typename Op0::Arguments op_0; + }; + + struct Params { + typename Op0::Params op_0; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, workspace) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops({ + Op0(params.op_0, get<0>(shared_storage)) + }) {} + + tuple ops; +}; + +template +struct Sm90VisitorImplBase { + + using SharedStorage = tuple< + typename Op0::SharedStorage, + typename Op1::SharedStorage + >; + + struct Arguments { + typename Op0::Arguments op_0; + typename Op1::Arguments op_1; + }; + + struct Params { + typename 
Op0::Params op_0; + typename Op1::Params op_1; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); + uint8_t* op_0_workspace = reinterpret_cast(workspace); + uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), + Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops({ + Op0(params.op_0, get<0>(shared_storage)), + Op1(params.op_1, get<1>(shared_storage)) + }) {} + + tuple ops; +}; + +template +struct Sm90VisitorImplBase { + + using SharedStorage = tuple< + typename Op0::SharedStorage, + typename Op1::SharedStorage, + typename Op2::SharedStorage + >; + + struct Arguments { + typename Op0::Arguments op_0; + typename Op1::Arguments op_1; + typename Op2::Arguments op_2; + }; + + struct Params { + typename Op0::Params op_0; + typename Op1::Params op_1; + typename Op2::Params op_2; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); + size_t op_1_workspace_size = Op1::get_workspace_size(problem_shape, args.op_1); + uint8_t* op_0_workspace = reinterpret_cast(workspace); + uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; + uint8_t* op_2_workspace = op_1_workspace + op_1_workspace_size; + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), + Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace), + Op2::to_underlying_arguments(problem_shape, args.op_2, op_2_workspace) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, 
Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op2::initialize_workspace(problem_shape, args.op_2, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops({ + Op0(params.op_0, get<0>(shared_storage)), + Op1(params.op_1, get<1>(shared_storage)), + Op2(params.op_2, get<2>(shared_storage)) + }) {} + + tuple ops; +}; + +template +struct Sm90VisitorImplBase { + + using SharedStorage = tuple< + typename Op0::SharedStorage, + typename Op1::SharedStorage, + typename Op2::SharedStorage, + typename Op3::SharedStorage + >; + + struct Arguments { + typename Op0::Arguments op_0; + typename Op1::Arguments op_1; + typename Op2::Arguments op_2; + typename Op3::Arguments op_3; + }; + + struct Params { + typename Op0::Params op_0; + typename Op1::Params op_1; + typename Op2::Params op_2; + typename Op3::Params op_3; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); + size_t op_1_workspace_size = Op1::get_workspace_size(problem_shape, args.op_1); + size_t op_2_workspace_size = Op2::get_workspace_size(problem_shape, args.op_2); + uint8_t* op_0_workspace = reinterpret_cast(workspace); + uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; + uint8_t* op_2_workspace = op_1_workspace + op_1_workspace_size; + uint8_t* op_3_workspace = op_2_workspace + op_2_workspace_size; + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), + Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace), + Op2::to_underlying_arguments(problem_shape, args.op_2, 
op_2_workspace), + Op3::to_underlying_arguments(problem_shape, args.op_3, op_3_workspace) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op3::get_workspace_size(problem_shape, args.op_3); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op2::initialize_workspace(problem_shape, args.op_2, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op3::initialize_workspace(problem_shape, args.op_3, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op3::get_workspace_size(problem_shape, args.op_3); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops({ + Op0(params.op_0, get<0>(shared_storage)), + Op1(params.op_1, get<1>(shared_storage)), + Op2(params.op_2, get<2>(shared_storage)), + Op3(params.op_3, get<3>(shared_storage)) + }) {} + + tuple ops; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/punica_kernels/include/cutlass/cutlass/epilogue/thread/activation.h b/server/punica_kernels/include/cutlass/cutlass/epilogue/thread/activation.h new file mode 100644 index 00000000..1a226a75 --- /dev/null +++ b/server/punica_kernels/include/cutlass/cutlass/epilogue/thread/activation.h @@ -0,0 +1,702 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This extends the contents of cutlass/functional.h with frequently used activation functions. 
+
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/numeric_conversion.h"
+#include "cutlass/constants.h"
+#include "cutlass/complex.h"
+#include "cutlass/array.h"
+#include "cutlass/half.h"
+#include "cutlass/functional.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace thread {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Identity operator
+template <typename T>
+struct Identity {
+  static const bool kIsHeavy = false;
+
+  CUTLASS_HOST_DEVICE
+  T operator()(T value) const {
+    return value;
+  }
+};
+
+template <typename T, int N>
+struct Identity<Array<T, N> > {
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> value) const {
+    return value;
+  }
+};
+
+/// Scale operator
+template <typename T>
+struct Scale {
+  struct Arguments {
+    using scale_type = T;
+    T scale = T(1);
+  };
+
+  CUTLASS_HOST_DEVICE
+  T operator()(T value, T scale) const {
+    multiplies<T> mul;
+    return mul(scale, value);
+  }
+
+  CUTLASS_HOST_DEVICE
+  T operator()(T value, Arguments args = Arguments()) const {
+    return this->operator()(value, args.scale);
+  }
+};
+
+template <typename T, int N>
+struct Scale<Array<T, N>> {
+  using Arguments = typename Scale<T>::Arguments;
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> values, T scale) const {
+    multiplies<Array<T, N>> mul;
+    return mul(scale, values);
+  }
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> values, Arguments args = Arguments()) const {
+    return this->operator()(values, args.scale);
+  }
+};
+
+/// Specialization to compose other activations with a defined unary operator
+/// e.g. Scale<Identity<T>>
+template